From 04e7399f0fd7b53636e4cd3bc8d83971fddfc9c3 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Tue, 18 Nov 2025 09:17:37 -0500 Subject: [PATCH] interrupt - activate - set context separately --- src/strands/event_loop/event_loop.py | 3 ++- src/strands/interrupt.py | 9 ++------- tests/strands/agent/test_agent.py | 3 ++- tests/strands/event_loop/test_event_loop.py | 3 ++- tests/strands/test_interrupt.py | 7 +------ 5 files changed, 9 insertions(+), 16 deletions(-) diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index 562de24b8..90776eaf2 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -483,7 +483,8 @@ async def _handle_tool_execution( if interrupts: # Session state stored on AfterInvocationEvent. - agent._interrupt_state.activate(context={"tool_use_message": message, "tool_results": tool_results}) + agent._interrupt_state.context = {"tool_use_message": message, "tool_results": tool_results} + agent._interrupt_state.activate() agent.event_loop_metrics.end_cycle(cycle_start_time, cycle_trace) yield EventLoopStopEvent( diff --git a/src/strands/interrupt.py b/src/strands/interrupt.py index 919927e1a..da89d772b 100644 --- a/src/strands/interrupt.py +++ b/src/strands/interrupt.py @@ -53,13 +53,8 @@ class _InterruptState: context: dict[str, Any] = field(default_factory=dict) activated: bool = False - def activate(self, context: dict[str, Any] | None = None) -> None: - """Activate the interrupt state. - - Args: - context: Context associated with the interrupt event. - """ - self.context = context or {} + def activate(self) -> None: + """Activate the interrupt state.""" self.activated = True def deactivate(self) -> None: diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index d04f57948..76aeadeff 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -2001,8 +2001,9 @@ def test_agent__call__resume_interrupt(mock_model, tool_decorated, agenerator): reason="test reason", ) - agent._interrupt_state.activate(context={"tool_use_message": tool_use_message, "tool_results": []}) + agent._interrupt_state.context = {"tool_use_message": tool_use_message, "tool_results": []} agent._interrupt_state.interrupts[interrupt.id] = interrupt + agent._interrupt_state.activate() interrupt_response = {} diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 9335f91a8..e51680f6f 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -964,8 +964,9 @@ async def test_event_loop_cycle_interrupt_resume(agent, model, tool, tool_times_ }, ] - agent._interrupt_state.activate(context={"tool_use_message": tool_use_message, "tool_results": tool_results}) + agent._interrupt_state.context = {"tool_use_message": tool_use_message, "tool_results": tool_results} agent._interrupt_state.interrupts[interrupt.id] = interrupt + agent._interrupt_state.activate() interrupt_response = {} diff --git a/tests/strands/test_interrupt.py b/tests/strands/test_interrupt.py index a45d524e4..d9079b01a 100644 --- a/tests/strands/test_interrupt.py +++ b/tests/strands/test_interrupt.py @@ -27,14 +27,9 @@ def test_interrupt_to_dict(interrupt): def test_interrupt_state_activate(): interrupt_state = _InterruptState() - interrupt_state.activate(context={"test": "context"}) - + interrupt_state.activate() assert interrupt_state.activated - tru_context = interrupt_state.context - exp_context = {"test": "context"} - assert tru_context == exp_context - def test_interrupt_state_deactivate(): interrupt_state = _InterruptState(context={"test": "context"}, activated=True)