From 14bbab978949cb1dd685a1db6b0692ec75829a57 Mon Sep 17 00:00:00 2001 From: Jack Yuan Date: Wed, 22 Apr 2026 16:19:52 -0400 Subject: [PATCH 1/2] feat(checkpoint): wire checkpointing into agent event loop --- src/strands/agent/agent.py | 51 ++++ src/strands/agent/agent_result.py | 18 +- src/strands/event_loop/event_loop.py | 58 ++++ .../experimental/checkpoint/__init__.py | 9 +- .../experimental/checkpoint/checkpoint.py | 14 +- src/strands/experimental/checkpoint/types.py | 37 +++ src/strands/types/_events.py | 7 +- src/strands/types/exceptions.py | 6 + tests/strands/agent/test_agent.py | 70 ++++- tests/strands/agent/test_agent_result.py | 84 ++++++ tests/strands/event_loop/test_event_loop.py | 117 ++++++++- .../test_event_loop_structured_output.py | 3 + .../checkpoint/test_checkpoint.py | 248 ++++++++++++++---- .../experimental/checkpoint/test_types.py | 17 ++ tests/strands/types/test__events.py | 37 ++- 15 files changed, 711 insertions(+), 65 deletions(-) create mode 100644 src/strands/experimental/checkpoint/types.py create mode 100644 tests/strands/experimental/checkpoint/test_types.py diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index e8ea3c9bc..e3ea9e860 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -29,6 +29,7 @@ from .._async import run_async from ..event_loop._retry import ModelRetryStrategy from ..event_loop.event_loop import INITIAL_DELAY, MAX_ATTEMPTS, MAX_DELAY, event_loop_cycle +from ..experimental.checkpoint import Checkpoint from ..tools._tool_helpers import generate_missing_tool_result_content from ..types._snapshot import ( SNAPSHOT_SCHEMA_VERSION, @@ -146,6 +147,7 @@ def __init__( tool_executor: ToolExecutor | None = None, retry_strategy: ModelRetryStrategy | _DefaultRetryStrategySentinel | None = _DEFAULT_RETRY_STRATEGY, concurrent_invocation_mode: ConcurrentInvocationMode = ConcurrentInvocationMode.THROW, + checkpointing: bool = False, ): """Initialize the Agent with the specified configuration. @@ -214,6 +216,11 @@ def __init__( Set to "unsafe_reentrant" to skip lock acquisition entirely, allowing concurrent invocations. Warning: "unsafe_reentrant" makes no guarantees about resulting behavior and is provided only for advanced use cases where the caller understands the risks. + checkpointing: When True, the event loop pauses at cycle boundaries (after model call, + after all tools execute) and returns an AgentResult with stop_reason="checkpoint" + and a populated ``checkpoint`` field. Persist the checkpoint and resume by passing a + ``CheckpointResumeContent`` block as the next prompt. Defaults to False. + See :mod:`strands.experimental.checkpoint` for usage and limitations. Raises: ValueError: If agent id contains path separators. @@ -304,6 +311,10 @@ def __init__( self._interrupt_state = _InterruptState() + # Checkpointing: when True, event loop pauses at cycle boundaries + self._checkpointing: bool = checkpointing + self._checkpoint_resume_context: Checkpoint | None = None + # Runtime state for model providers (e.g., server-side response ids) self._model_state: dict[str, Any] = {} @@ -998,6 +1009,46 @@ async def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages: if self._interrupt_state.activated: return [] + # Resume detection — must run before existing shape handling so checkpointResume + # blocks aren't misinterpreted as content blocks. Mirrors _InterruptState.resume() + # conventions (TypeError for shape, KeyError for lookup, ValueError for misconfig; + # error messages use the SDK's key= | message format). + if isinstance(prompt, list) and prompt: + has_checkpoint_resume = any(isinstance(c, dict) and "checkpointResume" in c for c in prompt) + if has_checkpoint_resume: + if not self._checkpointing: + raise ValueError( + "Received checkpointResume block but agent was created with " + "checkpointing=False. Pass checkpointing=True when constructing " + "the Agent to enable durable execution." + ) + + invalid_types = [ + key + for content in prompt + if isinstance(content, dict) + for key in content + if key != "checkpointResume" + ] + if invalid_types: + raise TypeError( + f"content_types=<{invalid_types}> | checkpointResume cannot be mixed with other content types" + ) + + if len(prompt) != 1: + raise TypeError( + f"block_count=<{len(prompt)}> | only one checkpointResume block permitted per prompt" + ) + + resume_block = prompt[0].get("checkpointResume", {}) + if not isinstance(resume_block, dict) or "checkpoint" not in resume_block: + raise KeyError("checkpoint | missing required key in checkpointResume block") + + checkpoint = Checkpoint.from_dict(resume_block["checkpoint"]) + self.load_snapshot(Snapshot.from_dict(checkpoint.snapshot)) + self._checkpoint_resume_context = checkpoint + return [] + messages: Messages | None = None if prompt is not None: # Check if the latest message is toolUse diff --git a/src/strands/agent/agent_result.py b/src/strands/agent/agent_result.py index f0a399f81..9d077d803 100644 --- a/src/strands/agent/agent_result.py +++ b/src/strands/agent/agent_result.py @@ -9,6 +9,7 @@ from pydantic import BaseModel +from ..experimental.checkpoint import Checkpoint from ..interrupt import Interrupt from ..telemetry.metrics import EventLoopMetrics from ..types.content import Message @@ -26,6 +27,9 @@ class AgentResult: state: Additional state information from the event loop. interrupts: List of interrupts if raised by user. structured_output: Parsed structured output when structured_output_model was specified. + checkpoint: Checkpoint captured when the agent paused for durable execution. + Populated only when stop_reason == "checkpoint". See + strands.experimental.checkpoint for usage. """ stop_reason: StopReason @@ -34,6 +38,7 @@ class AgentResult: state: Any interrupts: Sequence[Interrupt] | None = None structured_output: BaseModel | None = None + checkpoint: Checkpoint | None = None @property def context_size(self) -> int | None: @@ -85,15 +90,23 @@ def from_dict(cls, data: dict[str, Any]) -> "AgentResult": Returns: AgentResult instance Raises: - TypeError: If the data format is invalid@ + TypeError: If the data format is invalid """ if data.get("type") != "agent_result": raise TypeError(f"AgentResult.from_dict: unexpected type {data.get('type')!r}") message = cast(Message, data.get("message")) stop_reason = cast(StopReason, data.get("stop_reason")) + checkpoint_data = data.get("checkpoint") + checkpoint = Checkpoint.from_dict(checkpoint_data) if checkpoint_data else None - return cls(message=message, stop_reason=stop_reason, metrics=EventLoopMetrics(), state={}) + return cls( + message=message, + stop_reason=stop_reason, + metrics=EventLoopMetrics(), + state={}, + checkpoint=checkpoint, + ) def to_dict(self) -> dict[str, Any]: """Convert this AgentResult to JSON-serializable dictionary. @@ -105,4 +118,5 @@ def to_dict(self) -> dict[str, Any]: "type": "agent_result", "message": self.message, "stop_reason": self.stop_reason, + "checkpoint": self.checkpoint.to_dict() if self.checkpoint else None, } diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index bf1cc7a84..8735648b8 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -15,6 +15,7 @@ from opentelemetry import trace as trace_api +from ..experimental.checkpoint import Checkpoint from ..hooks import AfterModelCallEvent, BeforeModelCallEvent, MessageAddedEvent from ..telemetry.metrics import Trace from ..telemetry.tracer import Tracer, get_tracer @@ -122,6 +123,19 @@ async def event_loop_cycle( # Initialize state and get cycle trace if "request_state" not in invocation_state: invocation_state["request_state"] = {} + + # Consume checkpoint resume context (one-shot: cleared after reading). + # Cross-invocation state (resume context) lives on the agent; within-cycle + # transient state (resume position for the skip check, cycle index) lives + # in invocation_state. + resume_ctx = agent._checkpoint_resume_context + if resume_ctx is not None: + agent._checkpoint_resume_context = None + # after_tools completed that cycle, so next cycle starts at +1 + next_cycle = resume_ctx.cycle_index + 1 if resume_ctx.position == "after_tools" else resume_ctx.cycle_index + invocation_state["_checkpoint_cycle_index"] = next_cycle + invocation_state["_checkpoint_resume_position"] = resume_ctx.position + attributes = {"event_loop_cycle_id": str(invocation_state.get("event_loop_cycle_id"))} cycle_start_time, cycle_trace = agent.event_loop_metrics.start_cycle(attributes=attributes) invocation_state["event_loop_cycle_trace"] = cycle_trace @@ -181,6 +195,32 @@ async def event_loop_cycle( ) if stop_reason == "tool_use": + # Checkpoint after model call, before tool execution. + # One-shot pop: safe because after_model always returns before reaching + # after_tools, so the stashed position is only consumed once. + if agent._checkpointing: + resume_pos = invocation_state.pop("_checkpoint_resume_position", None) + if resume_pos == "after_model": + pass # Just resumed here — skip re-checkpoint, proceed to tools + else: + cycle_index = invocation_state.get("_checkpoint_cycle_index", 0) + checkpoint = Checkpoint( + position="after_model", + cycle_index=cycle_index, + snapshot=agent.take_snapshot(preset="session").to_dict(), + ) + agent.event_loop_metrics.end_cycle(cycle_start_time, cycle_trace) + if cycle_span: + tracer.end_event_loop_cycle_span(span=cycle_span, message=message) + yield EventLoopStopEvent( + "checkpoint", + message, + agent.event_loop_metrics, + invocation_state["request_state"], + checkpoint=checkpoint, + ) + return + # Handle tool execution tool_events = _handle_tool_execution( stop_reason, @@ -590,6 +630,24 @@ async def _handle_tool_execution( ) return + # Checkpoint after all tools complete, before the next model call. + if agent._checkpointing: + cycle_index = invocation_state.get("_checkpoint_cycle_index", 0) + invocation_state["_checkpoint_cycle_index"] = cycle_index + 1 + checkpoint = Checkpoint( + position="after_tools", + cycle_index=cycle_index, + snapshot=agent.take_snapshot(preset="session").to_dict(), + ) + yield EventLoopStopEvent( + "checkpoint", + message, + agent.event_loop_metrics, + invocation_state["request_state"], + checkpoint=checkpoint, + ) + return + events = recurse_event_loop( agent=agent, invocation_state=invocation_state, structured_output_context=structured_output_context ) diff --git a/src/strands/experimental/checkpoint/__init__.py b/src/strands/experimental/checkpoint/__init__.py index 848cda6d6..bbb7c4cb4 100644 --- a/src/strands/experimental/checkpoint/__init__.py +++ b/src/strands/experimental/checkpoint/__init__.py @@ -8,5 +8,12 @@ """ from .checkpoint import CHECKPOINT_SCHEMA_VERSION, Checkpoint, CheckpointPosition +from .types import CheckpointResumeContent, CheckpointResumeDict -__all__ = ["CHECKPOINT_SCHEMA_VERSION", "Checkpoint", "CheckpointPosition"] +__all__ = [ + "CHECKPOINT_SCHEMA_VERSION", + "Checkpoint", + "CheckpointPosition", + "CheckpointResumeContent", + "CheckpointResumeDict", +] diff --git a/src/strands/experimental/checkpoint/checkpoint.py b/src/strands/experimental/checkpoint/checkpoint.py index f37e403c9..bb0a002b7 100644 --- a/src/strands/experimental/checkpoint/checkpoint.py +++ b/src/strands/experimental/checkpoint/checkpoint.py @@ -17,6 +17,14 @@ - State via AgentResult.checkpoint field - Resume via checkpointResume content block in next agent() call +Interaction with interrupts: +- Interrupts take priority over checkpoints. If a tool raises an Interrupt + during a checkpointing=True cycle, the event loop returns + stop_reason="interrupt" (not "checkpoint"). The after_tools checkpoint + is never reached because the interrupt path returns early. +- This is intentional: interrupts require human input, checkpoints are + for worker-level durability. Different semantics, different priorities. + V0 Known Limitations: - Metrics reset on each resume call. The caller is responsible for aggregating metrics across a durable run. EventLoopMetrics reflects only the current call. @@ -35,6 +43,8 @@ from dataclasses import asdict, dataclass, field from typing import Any, Literal +from ...types.exceptions import CheckpointException + logger = logging.getLogger(__name__) CHECKPOINT_SCHEMA_VERSION = "1.0" @@ -79,11 +89,11 @@ def from_dict(cls, data: dict[str, Any]) -> "Checkpoint": data: Serialized checkpoint data. Raises: - ValueError: If schema_version doesn't match the current version. + CheckpointException: If schema_version doesn't match the current version. """ version = data.get("schema_version", "") if version != CHECKPOINT_SCHEMA_VERSION: - raise ValueError( + raise CheckpointException( f"Checkpoints with schema version {version!r} are not compatible " f"with current version {CHECKPOINT_SCHEMA_VERSION}." ) diff --git a/src/strands/experimental/checkpoint/types.py b/src/strands/experimental/checkpoint/types.py new file mode 100644 index 000000000..67b43f465 --- /dev/null +++ b/src/strands/experimental/checkpoint/types.py @@ -0,0 +1,37 @@ +"""Content-block types for checkpoint resume. + +Mirrors the interrupt pattern (`InterruptResponseContent` in `types/interrupt.py`). +Stays under `experimental/checkpoint/` for V0; will graduate to +`src/strands/types/checkpoint.py` when the feature exits experimental. +""" + +from typing import Any, TypedDict + + +class CheckpointResumeDict(TypedDict): + """Inner payload for a checkpointResume content block. + + Attributes: + checkpoint: Serialized Checkpoint as produced by ``Checkpoint.to_dict()``. + """ + + checkpoint: dict[str, Any] + + +class CheckpointResumeContent(TypedDict): + """Content block that resumes a paused durable agent. + + Pass a list containing exactly one instance of this type as the prompt to + ``Agent.invoke_async`` / ``Agent.__call__`` to resume from a checkpoint. + + Example:: + + result = await agent.invoke_async( + [{"checkpointResume": {"checkpoint": previous_checkpoint.to_dict()}}] + ) + + Attributes: + checkpointResume: The resume payload carrying the serialized checkpoint. + """ + + checkpointResume: CheckpointResumeDict diff --git a/src/strands/types/_events.py b/src/strands/types/_events.py index 1d5a5de79..01f2c9e3e 100644 --- a/src/strands/types/_events.py +++ b/src/strands/types/_events.py @@ -22,6 +22,7 @@ if TYPE_CHECKING: from ..agent import AgentResult from ..agent._agent_as_tool import _AgentAsTool + from ..experimental.checkpoint import Checkpoint from ..multiagent.base import MultiAgentResult, NodeResult @@ -227,6 +228,7 @@ def __init__( request_state: Any, interrupts: Sequence[Interrupt] | None = None, structured_output: BaseModel | None = None, + checkpoint: "Checkpoint | None" = None, ) -> None: """Initialize with the final execution results. @@ -237,8 +239,11 @@ def __init__( request_state: Final state of the agent execution interrupts: Interrupts raised by user during agent execution. structured_output: Optional structured output result + checkpoint: Optional checkpoint when stop_reason == "checkpoint". """ - super().__init__({"stop": (stop_reason, message, metrics, request_state, interrupts, structured_output)}) + super().__init__( + {"stop": (stop_reason, message, metrics, request_state, interrupts, structured_output, checkpoint)} + ) @property @override diff --git a/src/strands/types/exceptions.py b/src/strands/types/exceptions.py index 5db80a26e..4f766a77e 100644 --- a/src/strands/types/exceptions.py +++ b/src/strands/types/exceptions.py @@ -111,3 +111,9 @@ class ConcurrencyException(Exception): """ pass + + +class CheckpointException(Exception): + """Exception raised when checkpoint operations fail (e.g., incompatible schema version).""" + + pass diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 1e27274a1..e31106cd5 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -29,7 +29,12 @@ from strands.types._events import EventLoopStopEvent, ModelStreamEvent from strands.types.agent import ConcurrentInvocationMode from strands.types.content import Messages -from strands.types.exceptions import ConcurrencyException, ContextWindowOverflowException, EventLoopException +from strands.types.exceptions import ( + CheckpointException, + ConcurrencyException, + ContextWindowOverflowException, + EventLoopException, +) from strands.types.session import Session, SessionAgent, SessionMessage, SessionType from tests.fixtures.mock_session_repository import MockedSessionRepository from tests.fixtures.mocked_model_provider import MockedModelProvider @@ -2773,3 +2778,66 @@ def test_as_tool_defaults_description_when_agent_has_none(): tool = agent.as_tool() assert tool.tool_spec["description"] == "Use the researcher agent as a tool by providing a natural language input" + + +# ========================================================================= +# Checkpoint integration tests (Part B) +# ========================================================================= + + +def test_agent_checkpointing_defaults_to_false() -> None: + agent = Agent() + assert agent._checkpointing is False + assert agent._checkpoint_resume_context is None + + +def test_agent_checkpointing_flag_stored() -> None: + agent = Agent(checkpointing=True) + assert agent._checkpointing is True + assert agent._checkpoint_resume_context is None + + +@pytest.mark.asyncio +async def test_checkpoint_resume_without_checkpointing_flag_raises_value_error() -> None: + agent = Agent(checkpointing=False) + prompt = [{"checkpointResume": {"checkpoint": {}}}] + with pytest.raises(ValueError, match="checkpointing=True"): + await agent.invoke_async(prompt) + + +@pytest.mark.asyncio +async def test_checkpoint_resume_mixed_content_raises_type_error() -> None: + agent = Agent(checkpointing=True) + prompt = [ + {"checkpointResume": {"checkpoint": {}}}, + {"text": "bogus"}, + ] + with pytest.raises(TypeError, match="content_types"): + await agent.invoke_async(prompt) + + +@pytest.mark.asyncio +async def test_checkpoint_resume_multiple_blocks_raises_type_error() -> None: + agent = Agent(checkpointing=True) + prompt = [ + {"checkpointResume": {"checkpoint": {}}}, + {"checkpointResume": {"checkpoint": {}}}, + ] + with pytest.raises(TypeError, match="block_count"): + await agent.invoke_async(prompt) + + +@pytest.mark.asyncio +async def test_checkpoint_resume_missing_checkpoint_key_raises_key_error() -> None: + agent = Agent(checkpointing=True) + prompt = [{"checkpointResume": {}}] + with pytest.raises(KeyError, match="checkpoint"): + await agent.invoke_async(prompt) + + +@pytest.mark.asyncio +async def test_checkpoint_resume_schema_mismatch_raises_checkpoint_exception() -> None: + agent = Agent(checkpointing=True) + prompt = [{"checkpointResume": {"checkpoint": {"schema_version": "0.1", "position": "after_model"}}}] + with pytest.raises(CheckpointException, match="schema version"): + await agent.invoke_async(prompt) diff --git a/tests/strands/agent/test_agent_result.py b/tests/strands/agent/test_agent_result.py index 64391f299..3ea15d016 100644 --- a/tests/strands/agent/test_agent_result.py +++ b/tests/strands/agent/test_agent_result.py @@ -5,6 +5,7 @@ from pydantic import BaseModel from strands.agent.agent_result import AgentResult +from strands.experimental.checkpoint import Checkpoint from strands.interrupt import Interrupt from strands.telemetry.metrics import EventLoopMetrics from strands.types.content import Message @@ -110,6 +111,7 @@ def test_to_dict(mock_metrics, simple_message: Message): "type": "agent_result", "message": simple_message, "stop_reason": "end_turn", + "checkpoint": None, } @@ -384,3 +386,85 @@ def test_context_size_none_when_no_data(mock_metrics, simple_message: Message): mock_metrics.latest_context_size = None result = AgentResult(stop_reason="end_turn", message=simple_message, metrics=mock_metrics, state={}) assert result.context_size is None + + +# ========================================================================= +# Checkpoint field and round-trip serialization (Part B) +# +# Covers the V0 durable-execution contract: when stop_reason == "checkpoint", +# AgentResult carries a Checkpoint that round-trips through to_dict/from_dict. +# ========================================================================= + + +def test_agent_result_checkpoint_field_default_none() -> None: + result = AgentResult( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "hi"}]}, + metrics=EventLoopMetrics(), + state={}, + ) + assert result.checkpoint is None + + +def test_agent_result_accepts_checkpoint() -> None: + checkpoint = Checkpoint(position="after_model", cycle_index=0) + result = AgentResult( + stop_reason="checkpoint", + message={"role": "assistant", "content": [{"toolUse": {"toolUseId": "1", "name": "t", "input": {}}}]}, + metrics=EventLoopMetrics(), + state={}, + checkpoint=checkpoint, + ) + assert result.checkpoint is checkpoint + assert result.checkpoint.position == "after_model" + + +def test_agent_result_to_dict_includes_checkpoint() -> None: + checkpoint = Checkpoint(position="after_model", cycle_index=0) + result = AgentResult( + stop_reason="checkpoint", + message={"role": "assistant", "content": [{"toolUse": {"toolUseId": "1", "name": "t", "input": {}}}]}, + metrics=EventLoopMetrics(), + state={}, + checkpoint=checkpoint, + ) + d = result.to_dict() + assert d["checkpoint"] is not None + assert d["checkpoint"]["position"] == "after_model" + assert d["checkpoint"]["cycle_index"] == 0 + + +def test_agent_result_to_dict_checkpoint_none_when_absent() -> None: + result = AgentResult( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "hi"}]}, + metrics=EventLoopMetrics(), + state={}, + ) + d = result.to_dict() + assert d["checkpoint"] is None + + +def test_agent_result_from_dict_round_trips_checkpoint() -> None: + original = AgentResult( + stop_reason="checkpoint", + message={"role": "assistant", "content": [{"toolUse": {"toolUseId": "1", "name": "t", "input": {}}}]}, + metrics=EventLoopMetrics(), + state={}, + checkpoint=Checkpoint(position="after_tools", cycle_index=3), + ) + restored = AgentResult.from_dict(original.to_dict()) + assert restored.checkpoint is not None + assert restored.checkpoint.position == "after_tools" + assert restored.checkpoint.cycle_index == 3 + + +def test_agent_result_from_dict_handles_missing_checkpoint() -> None: + restored = AgentResult.from_dict( + { + "type": "agent_result", + "message": {"role": "assistant", "content": [{"text": "done"}]}, + "stop_reason": "end_turn", + } + ) + assert restored.checkpoint is None diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 871371f5f..5167f5a2d 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -13,6 +13,7 @@ import strands.telemetry from strands import Agent from strands.event_loop._retry import ModelRetryStrategy +from strands.experimental.checkpoint import Checkpoint from strands.hooks import ( AfterModelCallEvent, BeforeModelCallEvent, @@ -157,6 +158,8 @@ def agent(model, system_prompt, messages, tool_registry, thread_pool, hook_regis mock._interrupt_state = _InterruptState() mock._cancel_signal = threading.Event() mock._model_state = {} + mock._checkpointing = False + mock._checkpoint_resume_context = None mock.trace_attributes = {} mock.retry_strategy = ModelRetryStrategy() @@ -190,7 +193,7 @@ async def test_event_loop_cycle_text_response( invocation_state={}, ) events = await alist(stream) - tru_stop_reason, tru_message, _, tru_request_state, _, _ = events[-1]["stop"] + tru_stop_reason, tru_message, _, tru_request_state, _, _, _ = events[-1]["stop"] exp_stop_reason = "end_turn" exp_message = {"role": "assistant", "content": [{"text": "test text"}], "metadata": ANY} @@ -222,7 +225,7 @@ async def test_event_loop_cycle_text_response_throttling( invocation_state={}, ) events = await alist(stream) - tru_stop_reason, tru_message, _, tru_request_state, _, _ = events[-1]["stop"] + tru_stop_reason, tru_message, _, tru_request_state, _, _, _ = events[-1]["stop"] exp_stop_reason = "end_turn" exp_message = {"role": "assistant", "content": [{"text": "test text"}], "metadata": ANY} @@ -260,7 +263,7 @@ async def test_event_loop_cycle_exponential_backoff( invocation_state={}, ) events = await alist(stream) - tru_stop_reason, tru_message, _, tru_request_state, _, _ = events[-1]["stop"] + tru_stop_reason, tru_message, _, tru_request_state, _, _, _ = events[-1]["stop"] # Verify the final response assert tru_stop_reason == "end_turn" @@ -351,7 +354,7 @@ async def test_event_loop_cycle_tool_result( invocation_state={}, ) events = await alist(stream) - tru_stop_reason, tru_message, _, tru_request_state, _, _ = events[-1]["stop"] + tru_stop_reason, tru_message, _, tru_request_state, _, _, _ = events[-1]["stop"] exp_stop_reason = "end_turn" exp_message = {"role": "assistant", "content": [{"text": "test text"}], "metadata": ANY} @@ -469,7 +472,7 @@ async def test_event_loop_cycle_stop( invocation_state={"request_state": {"stop_event_loop": True}}, ) events = await alist(stream) - tru_stop_reason, tru_message, _, tru_request_state, _, _ = events[-1]["stop"] + tru_stop_reason, tru_message, _, tru_request_state, _, _, _ = events[-1]["stop"] exp_stop_reason = "tool_use" exp_message = { @@ -833,7 +836,7 @@ async def test_request_state_initialization(alist): invocation_state={}, ) events = await alist(stream) - _, _, _, tru_request_state, _, _ = events[-1]["stop"] + _, _, _, tru_request_state, _, _, _ = events[-1]["stop"] # Verify request_state was initialized to empty dict assert tru_request_state == {} @@ -845,7 +848,7 @@ async def test_request_state_initialization(alist): invocation_state={"request_state": initial_request_state}, ) events = await alist(stream) - _, _, _, tru_request_state, _, _ = events[-1]["stop"] + _, _, _, tru_request_state, _, _, _ = events[-1]["stop"] # Verify existing request_state was preserved assert tru_request_state == initial_request_state @@ -969,7 +972,7 @@ def interrupt_callback(event): stream = strands.event_loop.event_loop.event_loop_cycle(agent, invocation_state={}) events = await alist(stream) - tru_stop_reason, _, _, _, tru_interrupts, _ = events[-1]["stop"] + tru_stop_reason, _, _, _, tru_interrupts, _, _ = events[-1]["stop"] exp_stop_reason = "interrupt" exp_interrupts = [ Interrupt( @@ -1064,7 +1067,7 @@ def interrupt_callback(event): stream = strands.event_loop.event_loop.event_loop_cycle(agent, invocation_state={}) events = await alist(stream) - tru_stop_reason, _, _, _, _, _ = events[-1]["stop"] + tru_stop_reason, _, _, _, _, _, _ = events[-1]["stop"] exp_stop_reason = "end_turn" assert tru_stop_reason == exp_stop_reason @@ -1196,5 +1199,99 @@ async def test_event_loop_metrics_recorded_before_recursion( assert mock_end_cycle.call_count == 2 # Verify the event loop completed successfully - tru_stop_reason, _, _, _, _, _ = events[-1]["stop"] + tru_stop_reason, _, _, _, _, _, _ = events[-1]["stop"] assert tru_stop_reason == "end_turn" + + +# --- Checkpoint event loop integration (Tasks 9-10) --- + + +@pytest.mark.asyncio +async def test_event_loop_cycle_checkpoint_after_model( + agent, + model, + tool_stream, + agenerator, + alist, +): + """With checkpointing=True, tool_use stop_reason yields after_model checkpoint instead of running tools.""" + agent._checkpointing = True + agent._checkpoint_resume_context = None + agent.take_snapshot = unittest.mock.Mock( + return_value=unittest.mock.Mock(to_dict=lambda: {"data": {"messages": []}}) + ) + + model.stream.return_value = agenerator(tool_stream) + + stream = strands.event_loop.event_loop.event_loop_cycle( + agent=agent, + invocation_state={}, + ) + events = await alist(stream) + stop = events[-1]["stop"] + tru_stop_reason, _, _, _, _, _, tru_checkpoint = stop + + assert tru_stop_reason == "checkpoint" + assert tru_checkpoint is not None + assert tru_checkpoint.position == "after_model" + assert tru_checkpoint.cycle_index == 0 + + +@pytest.mark.asyncio +async def test_event_loop_cycle_checkpoint_after_tools( + agent, + model, + tool, + tool_stream, + agenerator, + alist, +): + """With checkpointing=True and resume from after_model, tools execute then yield after_tools checkpoint.""" + agent._checkpointing = True + agent.take_snapshot = unittest.mock.Mock( + return_value=unittest.mock.Mock(to_dict=lambda: {"data": {"messages": []}}) + ) + agent._checkpoint_resume_context = Checkpoint(position="after_model", cycle_index=0) + + model.stream.return_value = agenerator(tool_stream) + + stream = strands.event_loop.event_loop.event_loop_cycle( + agent=agent, + invocation_state={}, + ) + events = await alist(stream) + tru_stop_reason, _, _, _, _, _, tru_checkpoint = events[-1]["stop"] + + assert tru_stop_reason == "checkpoint" + assert tru_checkpoint is not None + assert tru_checkpoint.position == "after_tools" + assert tru_checkpoint.cycle_index == 0 + + +@pytest.mark.asyncio +async def test_event_loop_cycle_checkpoint_resume_after_tools_increments_cycle( + agent, + model, + tool_stream, + agenerator, + alist, +): + """Resuming from after_tools sets cycle_index to previous + 1 for the next after_model checkpoint.""" + agent._checkpointing = True + agent._checkpoint_resume_context = Checkpoint(position="after_tools", cycle_index=2) + agent.take_snapshot = unittest.mock.Mock( + return_value=unittest.mock.Mock(to_dict=lambda: {"data": {"messages": []}}) + ) + + model.stream.return_value = agenerator(tool_stream) + + stream = strands.event_loop.event_loop.event_loop_cycle( + agent=agent, + invocation_state={}, + ) + events = await alist(stream) + tru_stop_reason, _, _, _, _, _, tru_checkpoint = events[-1]["stop"] + + assert tru_stop_reason == "checkpoint" + assert tru_checkpoint.position == "after_model" + assert tru_checkpoint.cycle_index == 3 diff --git a/tests/strands/event_loop/test_event_loop_structured_output.py b/tests/strands/event_loop/test_event_loop_structured_output.py index 2d1150712..b8db1fd3d 100644 --- a/tests/strands/event_loop/test_event_loop_structured_output.py +++ b/tests/strands/event_loop/test_event_loop_structured_output.py @@ -60,6 +60,9 @@ def mock_agent(): agent._interrupt_state.activated = False agent._interrupt_state.context = {} agent._cancel_signal = threading.Event() + agent._model_state = {} + agent._checkpointing = False + agent._checkpoint_resume_context = None return agent diff --git a/tests/strands/experimental/checkpoint/test_checkpoint.py b/tests/strands/experimental/checkpoint/test_checkpoint.py index 4435fb3db..ea1221c09 100644 --- a/tests/strands/experimental/checkpoint/test_checkpoint.py +++ b/tests/strands/experimental/checkpoint/test_checkpoint.py @@ -2,52 +2,208 @@ import pytest +from strands import Agent, tool from strands.experimental.checkpoint import CHECKPOINT_SCHEMA_VERSION, Checkpoint +from strands.types.exceptions import CheckpointException +from tests.fixtures.mocked_model_provider import MockedModelProvider -class TestCheckpoint: - """Checkpoint dataclass serialization tests.""" - - def test_round_trip(self): - checkpoint = Checkpoint( - position="after_model", - cycle_index=1, - snapshot={"messages": []}, - app_data={"workflow_id": "wf-123"}, - ) - data = checkpoint.to_dict() - restored = Checkpoint.from_dict(data) - - assert restored.position == checkpoint.position - assert restored.cycle_index == checkpoint.cycle_index - assert restored.snapshot == checkpoint.snapshot - assert restored.app_data == checkpoint.app_data - assert restored.schema_version == CHECKPOINT_SCHEMA_VERSION - - def test_schema_version_immutable(self): - checkpoint = Checkpoint(position="after_tools") - assert checkpoint.schema_version == CHECKPOINT_SCHEMA_VERSION - - def test_schema_version_mismatch_raises(self): - data = Checkpoint(position="after_model").to_dict() - data["schema_version"] = "0.0" - with pytest.raises(ValueError, match="not compatible with current version"): - Checkpoint.from_dict(data) - - def test_defaults(self): - checkpoint = Checkpoint(position="after_model") - assert checkpoint.cycle_index == 0 - assert checkpoint.snapshot == {} - assert checkpoint.app_data == {} - - def test_from_dict_warns_on_unknown_fields(self, caplog): - data = Checkpoint(position="after_tools").to_dict() - data["unknown_future_field"] = "something" - restored = Checkpoint.from_dict(data) - assert restored.position == "after_tools" - assert "unknown_future_field" in caplog.text - - def test_from_dict_missing_schema_version_raises(self): - data = {"position": "after_model", "cycle_index": 0, "snapshot": {}, "app_data": {}} - with pytest.raises(ValueError, match="not compatible with current version"): - Checkpoint.from_dict(data) +def test_checkpoint_to_dict_from_dict_round_trip(): + checkpoint = Checkpoint( + position="after_model", + cycle_index=1, + snapshot={"messages": []}, + app_data={"workflow_id": "wf-123"}, + ) + data = checkpoint.to_dict() + restored = Checkpoint.from_dict(data) + + assert restored.position == checkpoint.position + assert restored.cycle_index == checkpoint.cycle_index + assert restored.snapshot == checkpoint.snapshot + assert restored.app_data == checkpoint.app_data + assert restored.schema_version == CHECKPOINT_SCHEMA_VERSION + + +def test_checkpoint_init_schema_version_immutable(): + checkpoint = Checkpoint(position="after_tools") + assert checkpoint.schema_version == CHECKPOINT_SCHEMA_VERSION + + +def test_checkpoint_init_defaults(): + checkpoint = Checkpoint(position="after_model") + assert checkpoint.cycle_index == 0 + assert checkpoint.snapshot == {} + assert checkpoint.app_data == {} + + +def test_checkpoint_from_dict_schema_version_mismatch_raises(): + data = Checkpoint(position="after_model").to_dict() + data["schema_version"] = "0.0" + with pytest.raises(CheckpointException, match="not compatible with current version"): + Checkpoint.from_dict(data) + + +def test_checkpoint_from_dict_missing_schema_version_raises(): + data = {"position": "after_model", "cycle_index": 0, "snapshot": {}, "app_data": {}} + with pytest.raises(CheckpointException, match="not compatible with current version"): + Checkpoint.from_dict(data) + + +def test_checkpoint_from_dict_unknown_fields_warns(caplog): + data = Checkpoint(position="after_tools").to_dict() + data["unknown_future_field"] = "something" + restored = Checkpoint.from_dict(data) + assert restored.position == "after_tools" + assert "unknown_future_field" in caplog.text + + +# ========================================================================= +# End-to-end integration tests (Part B) +# +# These tests exercise the full pause/resume cycle through agent.invoke_async, +# using real Agent instances (not mocks) and a scripted model provider. They prove: +# +# 1. Checkpoints round-trip through to_dict/from_dict across fresh Agent instances. +# 2. cycle_index is preserved across process-restart-style resumes. +# 3. Completed tool work survives worker loss — tools do not re-execute on resume. +# +# They do NOT cover mid-tool crashes (orchestrator responsibility) or stateful +# model server-side state (documented V0 limitation). +# ========================================================================= + + +def _assistant_tool_use(tool_use_id: str, name: str, input_data: dict) -> dict: + """Build a scripted assistant message that invokes a single tool.""" + return { + "role": "assistant", + "content": [{"toolUse": {"toolUseId": tool_use_id, "name": name, "input": input_data}}], + } + + +def _assistant_text(text: str) -> dict: + return {"role": "assistant", "content": [{"text": text}]} + + +@pytest.mark.asyncio +async def test_checkpoint_round_trip_across_cycles() -> None: + """Fresh Agent pauses, serialize/deserialize, new Agent resumes, runs to completion.""" + call_log: list[str] = [] + + @tool + def noop(step: str) -> str: + call_log.append(step) + return f"ran-{step}" + + scripted_model = MockedModelProvider( + [ + _assistant_tool_use("t1", "noop", {"step": "one"}), + _assistant_text("done"), + ] + ) + + agent_a = Agent(model=scripted_model, tools=[noop], checkpointing=True) + + # Cycle 0 — model requests tool, pause at after_model. + result_after_model = await agent_a.invoke_async("please run a tool") + assert result_after_model.stop_reason == "checkpoint" + assert result_after_model.checkpoint is not None + assert result_after_model.checkpoint.position == "after_model" + assert result_after_model.checkpoint.cycle_index == 0 + assert call_log == [] # tool has not yet run + + # Serialize/deserialize — simulates crossing a process or activity boundary. + checkpoint_wire = result_after_model.checkpoint.to_dict() + resumed_checkpoint = Checkpoint.from_dict(checkpoint_wire) + + # Fresh Agent instance resumes and runs tools, pausing at after_tools. + agent_b = Agent(model=scripted_model, tools=[noop], checkpointing=True) + result_after_tools = await agent_b.invoke_async( + [{"checkpointResume": {"checkpoint": resumed_checkpoint.to_dict()}}] + ) + assert result_after_tools.stop_reason == "checkpoint" + assert result_after_tools.checkpoint is not None + assert result_after_tools.checkpoint.position == "after_tools" + assert result_after_tools.checkpoint.cycle_index == 0 + assert call_log == ["one"] # tool ran exactly once + + # Resume once more — model returns end_turn, agent completes. + agent_c = Agent(model=scripted_model, tools=[noop], checkpointing=True) + result_done = await agent_c.invoke_async( + [{"checkpointResume": {"checkpoint": result_after_tools.checkpoint.to_dict()}}] + ) + assert result_done.stop_reason == "end_turn" + assert result_done.checkpoint is None + # Tool still only ran once across the whole durable run. + assert call_log == ["one"] + + +@pytest.mark.asyncio +async def test_crash_after_tools_does_not_rerun_completed_tools() -> None: + """3 tools run, agent is discarded ('crash'), fresh agent resumes, tools do not re-run.""" + calls_alpha: list[str] = [] + calls_beta: list[str] = [] + calls_gamma: list[str] = [] + + @tool + def alpha(payload: str) -> str: + calls_alpha.append(payload) + return f"alpha-{payload}" + + @tool + def beta(payload: str) -> str: + calls_beta.append(payload) + return f"beta-{payload}" + + @tool + def gamma(payload: str) -> str: + calls_gamma.append(payload) + return f"gamma-{payload}" + + # One assistant message requests all three tools, then an end_turn response. + scripted_model = MockedModelProvider( + [ + { + "role": "assistant", + "content": [ + {"toolUse": {"toolUseId": "t1", "name": "alpha", "input": {"payload": "a"}}}, + {"toolUse": {"toolUseId": "t2", "name": "beta", "input": {"payload": "b"}}}, + {"toolUse": {"toolUseId": "t3", "name": "gamma", "input": {"payload": "c"}}}, + ], + }, + _assistant_text("all done"), + ] + ) + + # Pre-crash agent: runs through after_model and after_tools. + pre_crash = Agent(model=scripted_model, tools=[alpha, beta, gamma], checkpointing=True) + after_model = await pre_crash.invoke_async("run the three tools") + assert after_model.stop_reason == "checkpoint" + assert after_model.checkpoint.position == "after_model" + + # Resume to run the tools. + pre_crash_b = Agent(model=scripted_model, tools=[alpha, beta, gamma], checkpointing=True) + after_tools = await pre_crash_b.invoke_async( + [{"checkpointResume": {"checkpoint": after_model.checkpoint.to_dict()}}] + ) + assert after_tools.stop_reason == "checkpoint" + assert after_tools.checkpoint.position == "after_tools" + # Exactly one call each, no double-runs. + assert calls_alpha == ["a"] + assert calls_beta == ["b"] + assert calls_gamma == ["c"] + + # "Crash": discard pre_crash_b entirely. Persist only the serialized checkpoint. + persisted = after_tools.checkpoint.to_dict() + del pre_crash, pre_crash_b + + # Post-crash: brand-new agent resumes from the after_tools checkpoint. + # The next model response is end_turn — no more tool use. + post_crash = Agent(model=scripted_model, tools=[alpha, beta, gamma], checkpointing=True) + final = await post_crash.invoke_async([{"checkpointResume": {"checkpoint": persisted}}]) + + assert final.stop_reason == "end_turn" + # No tool re-executed: call counts are unchanged. + assert calls_alpha == ["a"] + assert calls_beta == ["b"] + assert calls_gamma == ["c"] diff --git a/tests/strands/experimental/checkpoint/test_types.py b/tests/strands/experimental/checkpoint/test_types.py new file mode 100644 index 000000000..447a11e12 --- /dev/null +++ b/tests/strands/experimental/checkpoint/test_types.py @@ -0,0 +1,17 @@ +"""Tests for CheckpointResume content-block types.""" + +from strands.experimental.checkpoint import CheckpointResumeContent, CheckpointResumeDict + + +def test_checkpoint_resume_dict_carries_serialized_checkpoint() -> None: + block: CheckpointResumeDict = {"checkpoint": {"position": "after_model", "cycle_index": 0}} + assert block["checkpoint"]["position"] == "after_model" + assert block["checkpoint"]["cycle_index"] == 0 + + +def test_checkpoint_resume_content_wraps_resume_dict() -> None: + content: CheckpointResumeContent = { + "checkpointResume": {"checkpoint": {"position": "after_tools", "cycle_index": 2}} + } + assert content["checkpointResume"]["checkpoint"]["cycle_index"] == 2 + assert content["checkpointResume"]["checkpoint"]["position"] == "after_tools" diff --git a/tests/strands/types/test__events.py b/tests/strands/types/test__events.py index 48465e1f6..4f10bcce9 100644 --- a/tests/strands/types/test__events.py +++ b/tests/strands/types/test__events.py @@ -4,6 +4,7 @@ from pydantic import BaseModel +from strands.experimental.checkpoint import Checkpoint from strands.telemetry import EventLoopMetrics from strands.types._events import ( AgentAsToolStreamEvent, @@ -279,7 +280,7 @@ def test_initialization_without_structured_output(self): request_state = {"state": "final"} event = EventLoopStopEvent(stop_reason, message, metrics, request_state) - assert event["stop"] == (stop_reason, message, metrics, request_state, None, None) + assert event["stop"] == (stop_reason, message, metrics, request_state, None, None, None) assert event.is_callback_event is False def test_initialization_with_structured_output(self): @@ -291,7 +292,7 @@ def test_initialization_with_structured_output(self): structured_output = SampleModel(name="test", value=42) event = EventLoopStopEvent(stop_reason, message, metrics, request_state, structured_output) - assert event["stop"] == (stop_reason, message, metrics, request_state, structured_output, None) + assert event["stop"] == (stop_reason, message, metrics, request_state, structured_output, None, None) assert event.is_callback_event is False @@ -502,3 +503,35 @@ def test_is_tool_stream_event_subclass(self): assert isinstance(event, ToolStreamEvent) assert isinstance(event, TypedEvent) assert type(event) is AgentAsToolStreamEvent + + +# ========================================================================= +# EventLoopStopEvent checkpoint kwarg (Part B) +# ========================================================================= + + +def test_event_loop_stop_event_carries_checkpoint() -> None: + checkpoint = Checkpoint(position="after_model", cycle_index=0) + event = EventLoopStopEvent( + "checkpoint", + {"role": "assistant", "content": [{"text": "hi"}]}, + EventLoopMetrics(), + {}, + checkpoint=checkpoint, + ) + stop = event["stop"] + assert len(stop) == 7 + assert stop[0] == "checkpoint" + assert stop[6] is checkpoint + + +def test_event_loop_stop_event_checkpoint_defaults_to_none() -> None: + event = EventLoopStopEvent( + "end_turn", + {"role": "assistant", "content": [{"text": "done"}]}, + EventLoopMetrics(), + {}, + ) + stop = event["stop"] + assert len(stop) == 7 + assert stop[6] is None From e3df9d69ebd5d5aab2c01b09370b1986c201479b Mon Sep 17 00:00:00 2001 From: Jack Yuan Date: Wed, 22 Apr 2026 17:25:42 -0400 Subject: [PATCH 2/2] fix(checkpoint): address review feedback --- src/strands/agent/agent.py | 10 ++- src/strands/event_loop/event_loop.py | 76 ++++++++++++------- .../experimental/checkpoint/checkpoint.py | 6 +- .../checkpoint/test_checkpoint.py | 9 ++- tests/strands/models/test_model.py | 37 +++++---- 5 files changed, 88 insertions(+), 50 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index e3ea9e860..1ffb8c7a1 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -1011,10 +1011,14 @@ async def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages: # Resume detection — must run before existing shape handling so checkpointResume # blocks aren't misinterpreted as content blocks. Mirrors _InterruptState.resume() - # conventions (TypeError for shape, KeyError for lookup, ValueError for misconfig; - # error messages use the SDK's key= | message format). + # conventions (TypeError for shape, KeyError for lookup, ValueError for misconfig) + # with one addition: a schema-version mismatch raises CheckpointException (the + # SDK-wide convention for checkpoint errors; parallel to SessionException and + # SnapshotException). Error messages use the SDK's key= | message format. if isinstance(prompt, list) and prompt: - has_checkpoint_resume = any(isinstance(c, dict) and "checkpointResume" in c for c in prompt) + has_checkpoint_resume = any( + isinstance(content, dict) and "checkpointResume" in content for content in prompt + ) if has_checkpoint_resume: if not self._checkpointing: raise ValueError( diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index 8735648b8..11fa23967 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -15,7 +15,7 @@ from opentelemetry import trace as trace_api -from ..experimental.checkpoint import Checkpoint +from ..experimental.checkpoint import Checkpoint, CheckpointPosition from ..hooks import AfterModelCallEvent, BeforeModelCallEvent, MessageAddedEvent from ..telemetry.metrics import Trace from ..telemetry.tracer import Tracer, get_tracer @@ -76,6 +76,33 @@ def _has_tool_use_in_latest_message(messages: "Messages") -> bool: return False +def _build_checkpoint_stop_event( + agent: "Agent", + position: "CheckpointPosition", + cycle_index: int, + message: Message, + request_state: Any, +) -> EventLoopStopEvent: + """Build the EventLoopStopEvent that pauses the agent at a checkpoint boundary. + + Called from the two checkpoint emission sites (after_model in ``event_loop_cycle`` + and after_tools in ``_handle_tool_execution``) to keep snapshot capture and event + construction in one place. + """ + checkpoint = Checkpoint( + position=position, + cycle_index=cycle_index, + snapshot=agent.take_snapshot(preset="session").to_dict(), + ) + return EventLoopStopEvent( + "checkpoint", + message, + agent.event_loop_metrics, + request_state, + checkpoint=checkpoint, + ) + + async def event_loop_cycle( agent: "Agent", invocation_state: dict[str, Any], @@ -128,13 +155,15 @@ async def event_loop_cycle( # Cross-invocation state (resume context) lives on the agent; within-cycle # transient state (resume position for the skip check, cycle index) lives # in invocation_state. - resume_ctx = agent._checkpoint_resume_context - if resume_ctx is not None: + resume_context = agent._checkpoint_resume_context + if resume_context is not None: agent._checkpoint_resume_context = None # after_tools completed that cycle, so next cycle starts at +1 - next_cycle = resume_ctx.cycle_index + 1 if resume_ctx.position == "after_tools" else resume_ctx.cycle_index + next_cycle = ( + resume_context.cycle_index + 1 if resume_context.position == "after_tools" else resume_context.cycle_index + ) invocation_state["_checkpoint_cycle_index"] = next_cycle - invocation_state["_checkpoint_resume_position"] = resume_ctx.position + invocation_state["_checkpoint_resume_position"] = resume_context.position attributes = {"event_loop_cycle_id": str(invocation_state.get("event_loop_cycle_id"))} cycle_start_time, cycle_trace = agent.event_loop_metrics.start_cycle(attributes=attributes) @@ -199,25 +228,20 @@ async def event_loop_cycle( # One-shot pop: safe because after_model always returns before reaching # after_tools, so the stashed position is only consumed once. if agent._checkpointing: - resume_pos = invocation_state.pop("_checkpoint_resume_position", None) - if resume_pos == "after_model": + resume_position = invocation_state.pop("_checkpoint_resume_position", None) + if resume_position == "after_model": pass # Just resumed here — skip re-checkpoint, proceed to tools else: cycle_index = invocation_state.get("_checkpoint_cycle_index", 0) - checkpoint = Checkpoint( - position="after_model", - cycle_index=cycle_index, - snapshot=agent.take_snapshot(preset="session").to_dict(), - ) agent.event_loop_metrics.end_cycle(cycle_start_time, cycle_trace) if cycle_span: tracer.end_event_loop_cycle_span(span=cycle_span, message=message) - yield EventLoopStopEvent( - "checkpoint", - message, - agent.event_loop_metrics, - invocation_state["request_state"], - checkpoint=checkpoint, + yield _build_checkpoint_stop_event( + agent=agent, + position="after_model", + cycle_index=cycle_index, + message=message, + request_state=invocation_state["request_state"], ) return @@ -631,20 +655,18 @@ async def _handle_tool_execution( return # Checkpoint after all tools complete, before the next model call. + # Note: checkpoints are only emitted on tool_use cycles. An agent whose model + # returns end_turn on the very first call completes normally with no checkpoint + # (there is nothing durable to resume from). if agent._checkpointing: cycle_index = invocation_state.get("_checkpoint_cycle_index", 0) invocation_state["_checkpoint_cycle_index"] = cycle_index + 1 - checkpoint = Checkpoint( + yield _build_checkpoint_stop_event( + agent=agent, position="after_tools", cycle_index=cycle_index, - snapshot=agent.take_snapshot(preset="session").to_dict(), - ) - yield EventLoopStopEvent( - "checkpoint", - message, - agent.event_loop_metrics, - invocation_state["request_state"], - checkpoint=checkpoint, + message=message, + request_state=invocation_state["request_state"], ) return diff --git a/src/strands/experimental/checkpoint/checkpoint.py b/src/strands/experimental/checkpoint/checkpoint.py index bb0a002b7..bde591b47 100644 --- a/src/strands/experimental/checkpoint/checkpoint.py +++ b/src/strands/experimental/checkpoint/checkpoint.py @@ -52,10 +52,14 @@ CheckpointPosition = Literal["after_model", "after_tools"] -@dataclass +@dataclass(frozen=True) class Checkpoint: """Pause point in the agent loop. Treat as opaque — pass back to resume. + Immutable by design: a checkpoint represents a captured moment. Mutating + one after creation would decouple it from the snapshot it was built with, + which is always a bug. Build a new Checkpoint if you need different values. + Attributes: position: What just completed (after_model or after_tools). cycle_index: Which ReAct loop cycle (0-based). diff --git a/tests/strands/experimental/checkpoint/test_checkpoint.py b/tests/strands/experimental/checkpoint/test_checkpoint.py index ea1221c09..a2979ba8f 100644 --- a/tests/strands/experimental/checkpoint/test_checkpoint.py +++ b/tests/strands/experimental/checkpoint/test_checkpoint.py @@ -18,10 +18,11 @@ def test_checkpoint_to_dict_from_dict_round_trip(): data = checkpoint.to_dict() restored = Checkpoint.from_dict(data) - assert restored.position == checkpoint.position - assert restored.cycle_index == checkpoint.cycle_index - assert restored.snapshot == checkpoint.snapshot - assert restored.app_data == checkpoint.app_data + # Full-object equality catches any future-added field that isn't round-tripped + # correctly, without requiring this test to be updated for every new field. + assert restored == checkpoint + # schema_version is init=False, so it is always set to the current constant — + # asserted once explicitly since dataclass equality covers it via __eq__. assert restored.schema_version == CHECKPOINT_SCHEMA_VERSION diff --git a/tests/strands/models/test_model.py b/tests/strands/models/test_model.py index 11d4c10b9..ed5a218c8 100644 --- a/tests/strands/models/test_model.py +++ b/tests/strands/models/test_model.py @@ -550,15 +550,21 @@ def test_all_content_types(self): messages = [ {"role": "user", "content": [{"text": "hello world!"}]}, - {"role": "assistant", "content": [ - {"toolUse": {"toolUseId": "1", "name": "my_tool", "input": {"q": "test"}}}, - {"reasoningContent": {"reasoningText": {"text": "Let me think."}}}, - {"guardContent": {"text": {"text": "Filtered."}}}, - {"citationsContent": {"content": [{"text": "Citation."}]}}, - ]}, - {"role": "user", "content": [ - {"toolResult": {"toolUseId": "1", "content": [{"text": "tool output here"}]}}, - ]}, + { + "role": "assistant", + "content": [ + {"toolUse": {"toolUseId": "1", "name": "my_tool", "input": {"q": "test"}}}, + {"reasoningContent": {"reasoningText": {"text": "Let me think."}}}, + {"guardContent": {"text": {"text": "Filtered."}}}, + {"citationsContent": {"content": [{"text": "Citation."}]}}, + ], + }, + { + "role": "user", + "content": [ + {"toolResult": {"toolUseId": "1", "content": [{"text": "tool output here"}]}}, + ], + }, ] result = _estimate_tokens_with_heuristic( messages=messages, @@ -574,9 +580,12 @@ def test_non_serializable_inputs(self): result = _estimate_tokens_with_heuristic( messages=[ - {"role": "assistant", "content": [ - {"toolUse": {"toolUseId": "1", "name": "my_tool", "input": {"data": b"bytes"}}}, - ]}, + { + "role": "assistant", + "content": [ + {"toolUse": {"toolUseId": "1", "name": "my_tool", "input": {"data": b"bytes"}}}, + ], + }, ], tool_specs=[{"name": "t", "inputSchema": {"json": {"default": b"bytes"}}}], ) @@ -598,9 +607,7 @@ def _block_tiktoken(name, *args, **kwargs): monkeypatch.setattr("builtins.__import__", _block_tiktoken) try: - result = await model.count_tokens( - messages=[{"role": "user", "content": [{"text": "hello world!"}]}] - ) + result = await model.count_tokens(messages=[{"role": "user", "content": [{"text": "hello world!"}]}]) assert result == 3 # ceil(12 / 4) finally: model_module._get_encoding.cache_clear()