-
Notifications
You must be signed in to change notification settings - Fork 327
feat: expose user-defined state in MultiAgent Graph #703
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
b82496c
0a8f464
1087bd6
caa9d1e
d081102
84cebea
b4314f5
a648268
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,6 +3,8 @@ | |
Provides minimal foundation for multi-agent patterns (Swarm, Graph). | ||
""" | ||
|
||
import copy | ||
import json | ||
from abc import ABC, abstractmethod | ||
from dataclasses import dataclass, field | ||
from enum import Enum | ||
|
@@ -22,6 +24,105 @@ class Status(Enum): | |
FAILED = "failed" | ||
|
||
|
||
@dataclass | ||
class MultiAgentNode: | ||
"""Base class for nodes in multi-agent systems.""" | ||
|
||
node_id: str | ||
|
||
def __hash__(self) -> int: | ||
"""Return hash for MultiAgentNode based on node_id.""" | ||
return hash(self.node_id) | ||
|
||
def __eq__(self, other: Any) -> bool: | ||
"""Return equality for MultiAgentNode based on node_id.""" | ||
if not isinstance(other, MultiAgentNode): | ||
return False | ||
return self.node_id == other.node_id | ||
|
||
|
||
@dataclass | ||
class SharedContext: | ||
"""Shared context between multi-agent nodes. | ||
|
||
This class provides a key-value store for sharing information across nodes | ||
in multi-agent systems like Graph and Swarm. It validates that all values | ||
are JSON serializable to ensure compatibility. | ||
""" | ||
|
||
context: dict[str, dict[str, Any]] = field(default_factory=dict) | ||
|
||
def add_context(self, node: MultiAgentNode, key: str, value: Any) -> None: | ||
"""Add context for a specific node. | ||
dbschmigelski marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
Args: | ||
node: The node object to add context for | ||
key: The key to store the value under | ||
value: The value to store (must be JSON serializable) | ||
|
||
Raises: | ||
ValueError: If key is invalid or value is not JSON serializable | ||
""" | ||
self._validate_key(key) | ||
self._validate_json_serializable(value) | ||
|
||
if node.node_id not in self.context: | ||
self.context[node.node_id] = {} | ||
self.context[node.node_id][key] = value | ||
|
||
def get_context(self, node: MultiAgentNode, key: str | None = None) -> Any: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure I'm set on this approach because of backwards compatibility. If
|
||
"""Get context for a specific node. | ||
|
||
Args: | ||
node: The node object to get context for | ||
key: The specific key to retrieve (if None, returns all context for the node) | ||
|
||
Returns: | ||
The stored value, entire context dict for the node, or None if not found | ||
""" | ||
if node.node_id not in self.context: | ||
return None if key else {} | ||
|
||
if key is None: | ||
return copy.deepcopy(self.context[node.node_id]) | ||
else: | ||
value = self.context[node.node_id].get(key) | ||
return copy.deepcopy(value) if value is not None else None | ||
|
||
def _validate_key(self, key: str) -> None: | ||
"""Validate that a key is valid. | ||
|
||
Args: | ||
key: The key to validate | ||
|
||
Raises: | ||
ValueError: If key is invalid | ||
""" | ||
if key is None: | ||
raise ValueError("Key cannot be None") | ||
if not isinstance(key, str): | ||
raise ValueError("Key must be a string") | ||
if not key.strip(): | ||
raise ValueError("Key cannot be empty") | ||
|
||
def _validate_json_serializable(self, value: Any) -> None: | ||
"""Validate that a value is JSON serializable. | ||
|
||
Args: | ||
value: The value to validate | ||
|
||
Raises: | ||
ValueError: If value is not JSON serializable | ||
""" | ||
try: | ||
json.dumps(value) | ||
except (TypeError, ValueError) as e: | ||
raise ValueError( | ||
f"Value is not JSON serializable: {type(value).__name__}. " | ||
f"Only JSON-compatible types (str, int, float, bool, list, dict, None) are allowed." | ||
) from e | ||
|
||
|
||
@dataclass | ||
class NodeResult: | ||
"""Unified result from node execution - handles both Agent and nested MultiAgentBase results. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -29,7 +29,7 @@ | |
from ..telemetry import get_tracer | ||
from ..types.content import ContentBlock, Messages | ||
from ..types.event_loop import Metrics, Usage | ||
from .base import MultiAgentBase, MultiAgentResult, NodeResult, Status | ||
from .base import MultiAgentBase, MultiAgentNode, MultiAgentResult, NodeResult, SharedContext, Status | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
@@ -46,6 +46,7 @@ class GraphState: | |
task: The original input prompt/query provided to the graph execution. | ||
This represents the actual work to be performed by the graph as a whole. | ||
Entry point nodes receive this task as their input if they have no dependencies. | ||
shared_context: Context shared between graph nodes for storing user-defined state. | ||
""" | ||
|
||
# Task (with default empty string) | ||
|
@@ -61,6 +62,9 @@ class GraphState: | |
# Results | ||
results: dict[str, NodeResult] = field(default_factory=dict) | ||
|
||
# User-defined state shared across nodes | ||
shared_context: "SharedContext" = field(default_factory=lambda: SharedContext()) | ||
|
||
# Accumulated metrics | ||
accumulated_usage: Usage = field(default_factory=lambda: Usage(inputTokens=0, outputTokens=0, totalTokens=0)) | ||
accumulated_metrics: Metrics = field(default_factory=lambda: Metrics(latencyMs=0)) | ||
|
@@ -126,7 +130,7 @@ def should_traverse(self, state: GraphState) -> bool: | |
|
||
|
||
@dataclass | ||
class GraphNode: | ||
class GraphNode(MultiAgentNode): | ||
"""Represents a node in the graph. | ||
|
||
The execution_status tracks the node's lifecycle within graph orchestration: | ||
|
@@ -135,7 +139,6 @@ class GraphNode: | |
- COMPLETED/FAILED: Node finished executing (regardless of result quality) | ||
""" | ||
|
||
node_id: str | ||
executor: Agent | MultiAgentBase | ||
dependencies: set["GraphNode"] = field(default_factory=set) | ||
execution_status: Status = Status.PENDING | ||
|
@@ -389,6 +392,25 @@ def __init__( | |
self.state = GraphState() | ||
self.tracer = get_tracer() | ||
|
||
@property | ||
def shared_context(self) -> SharedContext: | ||
"""Access to the shared context for storing user-defined state across graph nodes. | ||
|
||
Returns: | ||
The SharedContext instance that can be used to store and retrieve | ||
information that should be accessible to all nodes in the graph. | ||
|
||
Example: | ||
```python | ||
graph = Graph(...) | ||
node1 = graph.nodes["node1"] | ||
node2 = graph.nodes["node2"] | ||
graph.shared_context.add_context(node1, "file_reference", "/path/to/file") | ||
graph.shared_context.get_context(node2, "file_reference") | ||
``` | ||
""" | ||
return self.state.shared_context | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sorry for the delay I need to take a look at this again more deeply. What is strange to be is that we have different hierarchies for Swarm and Graph. For Swarm we have SharedContext at the same level as SwarmState. For Graph we have SharedContext nested within GraphState. Even if we kept it this way I'm not sure we need this convenience method to surface shared context at the GraphNode level. It feels strange to place the context on the state then surface it higher. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'll come back tomorrow with a final decision on the path for this. |
||
|
||
def __call__(self, task: str | list[ContentBlock], **kwargs: Any) -> GraphResult: | ||
"""Invoke the graph synchronously.""" | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit we can remove the hash from SwarmNode since we can now inherit