Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/strands/interrupt.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ def resume(self, prompt: "AgentInput") -> None:

self.interrupts[interrupt_id].response = interrupt_response

self.context["responses"] = contents

def to_dict(self) -> dict[str, Any]:
"""Serialize to dict for session management."""
return asdict(self)
Expand Down
17 changes: 10 additions & 7 deletions src/strands/types/interrupt.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,19 +71,14 @@ def approve(self, event: BeforeToolCallEvent) -> None:
- Interrupts are session managed in-between return and user response.
"""

from typing import TYPE_CHECKING, Any, Protocol, TypedDict
from typing import Any, Protocol, TypedDict

from ..interrupt import Interrupt, InterruptException

if TYPE_CHECKING:
from ..agent import Agent


class _Interruptible(Protocol):
"""Interface that adds interrupt support to hook events and tools."""

agent: "Agent"

def interrupt(self, name: str, reason: Any = None, response: Any = None) -> Any:
"""Trigger the interrupt with a reason.

Expand All @@ -97,9 +92,17 @@ def interrupt(self, name: str, reason: Any = None, response: Any = None) -> Any:

Raises:
InterruptException: If human input is required.
RuntimeError: If agent instance attribute not set.
"""
for attr_name in ["agent", "source"]:
if hasattr(self, attr_name):
agent = getattr(self, attr_name)
break
else:
raise RuntimeError("agent instance attribute not set")

id = self._interrupt_id(name)
state = self.agent._interrupt_state
state = agent._interrupt_state

interrupt_ = state.interrupts.setdefault(id, Interrupt(id, name, reason, response))
if interrupt_.response:
Expand Down
4 changes: 4 additions & 0 deletions tests/strands/test_interrupt.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,10 @@ def test_interrupt_state_resume():
exp_response = "test response"
assert tru_response == exp_response

tru_context = interrupt_state.context
exp_context = {"responses": prompt}
assert tru_context == exp_context


def test_interrupt_state_resumse_deactivated():
interrupt_state = _InterruptState(activated=False)
Expand Down
9 changes: 9 additions & 0 deletions tests/strands/types/test_interrupt.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,12 @@ def test_interrupt_hook_event_interrupt_response_empty(interrupt, agent, interru

with pytest.raises(InterruptException):
interrupt_hook_event.interrupt("test_name")


def test_interrupt_hook_event_interrupt_missing_agent():
class Event(_Interruptible):
pass

event = Event()
with pytest.raises(RuntimeError, match="agent instance attribute not set"):
event.interrupt("test_name")
Loading