Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion src/strands/experimental/hooks/multiagent/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,14 @@
is used—hooks read from the orchestrator directly.
"""

import uuid
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

from typing_extensions import override

from ....hooks import BaseHookEvent
from ....types.interrupt import _Interruptible

if TYPE_CHECKING:
from ....multiagent.base import MultiAgentBase
Expand All @@ -28,18 +32,37 @@ class MultiAgentInitializedEvent(BaseHookEvent):


@dataclass
class BeforeNodeCallEvent(BaseHookEvent):
class BeforeNodeCallEvent(BaseHookEvent, _Interruptible):
"""Event triggered before individual node execution starts.

Attributes:
source: The multi-agent orchestrator instance
node_id: ID of the node about to execute
invocation_state: Configuration that user passes in
cancel_node: A user defined message that when set, will cancel the node execution.
The message will be placed into the node result with an error status. If set to `True`, Strands will cancel
the node and use a default cancel message.
"""

source: "MultiAgentBase"
node_id: str
invocation_state: dict[str, Any] | None = None
cancel_node: bool | str = False

def _can_write(self, name: str) -> bool:
return name in ["cancel_node"]

@override
def _interrupt_id(self, name: str) -> str:
"""Unique id for the interrupt.

Args:
name: User defined name for the interrupt.

Returns:
Interrupt id.
"""
return f"v1:before_node_call:{self.node_id}:{uuid.uuid5(uuid.NAMESPACE_OID, name)}"


@dataclass
Expand Down
2 changes: 2 additions & 0 deletions src/strands/interrupt.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ def resume(self, prompt: "AgentInput") -> None:

self.interrupts[interrupt_id].response = interrupt_response

self.context["responses"] = contents

def to_dict(self) -> dict[str, Any]:
"""Serialize to dict for session management."""
return asdict(self)
Expand Down
40 changes: 27 additions & 13 deletions src/strands/multiagent/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,29 +12,34 @@

from .._async import run_async
from ..agent import AgentResult
from ..interrupt import Interrupt
from ..types.event_loop import Metrics, Usage
from ..types.multiagent import MultiAgentInput

logger = logging.getLogger(__name__)


class Status(Enum):
"""Execution status for both graphs and nodes."""
"""Execution status for both graphs and nodes.

Attributes:
PENDING: Task has not started execution yet.
EXECUTING: Task is currently running.
COMPLETED: Task finished successfully.
FAILED: Task encountered an error and could not complete.
INTERRUPTED: Task was interrupted by user.
"""

PENDING = "pending"
EXECUTING = "executing"
COMPLETED = "completed"
FAILED = "failed"
INTERRUPTED = "interrupted"


@dataclass
class NodeResult:
"""Unified result from node execution - handles both Agent and nested MultiAgentBase results.

The status field represents the semantic outcome of the node's work:
- COMPLETED: The node's task was successfully accomplished
- FAILED: The node's task failed or produced an error
"""
"""Unified result from node execution - handles both Agent and nested MultiAgentBase results."""

# Core result data - single AgentResult, nested MultiAgentResult, or Exception
result: Union[AgentResult, "MultiAgentResult", Exception]
Expand All @@ -47,6 +52,7 @@ class NodeResult:
accumulated_usage: Usage = field(default_factory=lambda: Usage(inputTokens=0, outputTokens=0, totalTokens=0))
accumulated_metrics: Metrics = field(default_factory=lambda: Metrics(latencyMs=0))
execution_count: int = 0
interrupts: list[Interrupt] = field(default_factory=list)

def get_agent_results(self) -> list[AgentResult]:
"""Get all AgentResult objects from this node, flattened if nested."""
Expand Down Expand Up @@ -78,6 +84,7 @@ def to_dict(self) -> dict[str, Any]:
"accumulated_usage": self.accumulated_usage,
"accumulated_metrics": self.accumulated_metrics,
"execution_count": self.execution_count,
"interrupts": [interrupt.to_dict() for interrupt in self.interrupts],
}

@classmethod
Expand All @@ -100,31 +107,32 @@ def from_dict(cls, data: dict[str, Any]) -> "NodeResult":
usage = _parse_usage(data.get("accumulated_usage", {}))
metrics = _parse_metrics(data.get("accumulated_metrics", {}))

interrupts = []
for interrupt_data in data.get("interrupts", []):
interrupts.append(Interrupt(**interrupt_data))

return cls(
result=result,
execution_time=int(data.get("execution_time", 0)),
status=Status(data.get("status", "pending")),
accumulated_usage=usage,
accumulated_metrics=metrics,
execution_count=int(data.get("execution_count", 0)),
interrupts=interrupts,
)


@dataclass
class MultiAgentResult:
"""Result from multi-agent execution with accumulated metrics.

The status field represents the outcome of the MultiAgentBase execution:
- COMPLETED: The execution was successfully accomplished
- FAILED: The execution failed or produced an error
"""
"""Result from multi-agent execution with accumulated metrics."""

status: Status = Status.PENDING
results: dict[str, NodeResult] = field(default_factory=lambda: {})
accumulated_usage: Usage = field(default_factory=lambda: Usage(inputTokens=0, outputTokens=0, totalTokens=0))
accumulated_metrics: Metrics = field(default_factory=lambda: Metrics(latencyMs=0))
execution_count: int = 0
execution_time: int = 0
interrupts: list[Interrupt] = field(default_factory=list)

@classmethod
def from_dict(cls, data: dict[str, Any]) -> "MultiAgentResult":
Expand All @@ -136,13 +144,18 @@ def from_dict(cls, data: dict[str, Any]) -> "MultiAgentResult":
usage = _parse_usage(data.get("accumulated_usage", {}))
metrics = _parse_metrics(data.get("accumulated_metrics", {}))

interrupts = []
for interrupt_data in data.get("interrupts", []):
interrupts.append(Interrupt(**interrupt_data))

multiagent_result = cls(
status=Status(data["status"]),
results=results,
accumulated_usage=usage,
accumulated_metrics=metrics,
execution_count=int(data.get("execution_count", 0)),
execution_time=int(data.get("execution_time", 0)),
interrupts=interrupts,
)
return multiagent_result

Expand All @@ -156,6 +169,7 @@ def to_dict(self) -> dict[str, Any]:
"accumulated_metrics": self.accumulated_metrics,
"execution_count": self.execution_count,
"execution_time": self.execution_time,
"interrupts": [interrupt.to_dict() for interrupt in self.interrupts],
}


Expand Down
4 changes: 2 additions & 2 deletions src/strands/multiagent/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -964,7 +964,7 @@ def _build_node_input(self, node: GraphNode) -> list[ContentBlock]:
if isinstance(self.state.task, str):
return [ContentBlock(text=self.state.task)]
else:
return self.state.task
return cast(list[ContentBlock], self.state.task)

# Combine task with dependency outputs
node_input = []
Expand All @@ -975,7 +975,7 @@ def _build_node_input(self, node: GraphNode) -> list[ContentBlock]:
else:
# Add task content blocks with a prefix
node_input.append(ContentBlock(text="Original Task:"))
node_input.extend(self.state.task)
node_input.extend(cast(list[ContentBlock], self.state.task))

# Add dependency outputs
node_input.append(ContentBlock(text="\nInputs from previous nodes:"))
Expand Down
Loading
Loading