diff --git a/AGENTS.md b/AGENTS.md index 3615e713a..8835b45c8 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -152,6 +152,8 @@ strands-agents/ │ │ │ ├── tools/ # Bidi tools │ │ │ ├── types/ # Bidi types │ │ │ └── _async/ # Async utilities +│ │ ├── checkpoint/ # Durable agent execution checkpoints +│ │ │ └── checkpoint.py # Checkpoint dataclass and serialization │ │ ├── hooks/ # Experimental hooks │ │ │ ├── events.py │ │ │ └── multiagent/ diff --git a/src/strands/experimental/__init__.py b/src/strands/experimental/__init__.py index 3c1d0ee46..cbd9a713e 100644 --- a/src/strands/experimental/__init__.py +++ b/src/strands/experimental/__init__.py @@ -3,7 +3,7 @@ This module implements experimental features that are subject to change in future revisions without notice. """ -from . import steering, tools +from . import checkpoint, steering, tools from .agent_config import config_to_agent -__all__ = ["config_to_agent", "tools", "steering"] +__all__ = ["checkpoint", "config_to_agent", "tools", "steering"] diff --git a/src/strands/experimental/checkpoint/__init__.py b/src/strands/experimental/checkpoint/__init__.py new file mode 100644 index 000000000..848cda6d6 --- /dev/null +++ b/src/strands/experimental/checkpoint/__init__.py @@ -0,0 +1,12 @@ +"""Experimental checkpoint types for durable agent execution. + +This module is experimental and subject to change in future revisions without notice. + +Checkpoints enable crash-resilient agent workflows by capturing agent state at +cycle boundaries in the agent loop. A durability provider (e.g. Temporal) can +persist checkpoints and resume from them after failures. +""" + +from .checkpoint import CHECKPOINT_SCHEMA_VERSION, Checkpoint, CheckpointPosition + +__all__ = ["CHECKPOINT_SCHEMA_VERSION", "Checkpoint", "CheckpointPosition"] diff --git a/src/strands/experimental/checkpoint/checkpoint.py b/src/strands/experimental/checkpoint/checkpoint.py new file mode 100644 index 000000000..f37e403c9 --- /dev/null +++ b/src/strands/experimental/checkpoint/checkpoint.py @@ -0,0 +1,94 @@ +"""Checkpoint system for durable agent execution. + +Checkpoints enable crash-resilient agent workflows by capturing agent state at +cycle boundaries in the agent loop. A durability provider (e.g. Temporal) can +persist checkpoints and resume from them after failures. + +Two checkpoint positions per ReAct cycle: +- after_model: model call completed, tools not yet executed. +- after_tools: all tools executed, next model call pending. + +Per-tool granularity is handled by the ToolExecutor abstraction (e.g. +TemporalToolExecutor routes each tool to a separate Temporal activity). +The SDK checkpoint operates at cycle boundaries. + +User-facing pattern (same as interrupts): +- Pause via stop_reason="checkpoint" on AgentResult +- State via AgentResult.checkpoint field +- Resume via checkpointResume content block in next agent() call + +V0 Known Limitations: +- Metrics reset on each resume call. The caller is responsible for aggregating + metrics across a durable run. EventLoopMetrics reflects only the current call. +- OpenAIResponsesModel(stateful=True) is not supported. The server-side + response_id (_model_state) is not captured in the snapshot. +- When position is "after_tools", AgentResult.message is the assistant message + that requested the tools; tool results are in the snapshot messages. +- BeforeInvocationEvent and AfterInvocationEvent fire on every resume call, + same as interrupts. Hooks counting invocations will see each resume as a + separate invocation. +- Per-tool granularity within a cycle requires a custom ToolExecutor + (e.g. TemporalToolExecutor). +""" + +import logging +from dataclasses import asdict, dataclass, field +from typing import Any, Literal + +logger = logging.getLogger(__name__) + +CHECKPOINT_SCHEMA_VERSION = "1.0" + +CheckpointPosition = Literal["after_model", "after_tools"] + + +@dataclass +class Checkpoint: + """Pause point in the agent loop. Treat as opaque — pass back to resume. + + Attributes: + position: What just completed (after_model or after_tools). + cycle_index: Which ReAct loop cycle (0-based). + snapshot: Serialized agent state as a dict, produced by ``Snapshot.to_dict()``. + Stored as ``dict[str, Any]`` (not a ``Snapshot`` object) because checkpoints + must be JSON-serializable for cross-process persistence. The consumer + reconstructs via ``Snapshot.from_dict()`` on resume. + app_data: Application-level internal state data. The SDK does not read + or modify this. Applications can store arbitrary data needed across + checkpoint boundaries (e.g. session context, workflow metadata). + Separate from ``Snapshot.app_data`` which captures agent-state-level + data managed by the SDK. + schema_version: Rejects mismatches on resume across schema versions. + """ + + position: CheckpointPosition + cycle_index: int = 0 + snapshot: dict[str, Any] = field(default_factory=dict) + app_data: dict[str, Any] = field(default_factory=dict) + schema_version: str = field(init=False, default=CHECKPOINT_SCHEMA_VERSION) + + def to_dict(self) -> dict[str, Any]: + """Serialize for persistence.""" + return asdict(self) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "Checkpoint": + """Reconstruct from a dict produced by to_dict(). + + Args: + data: Serialized checkpoint data. + + Raises: + ValueError: If schema_version doesn't match the current version. + """ + version = data.get("schema_version", "") + if version != CHECKPOINT_SCHEMA_VERSION: + raise ValueError( + f"Checkpoints with schema version {version!r} are not compatible " + f"with current version {CHECKPOINT_SCHEMA_VERSION}." + ) + known_keys = {k for k in cls.__dataclass_fields__ if k != "schema_version"} + unknown_keys = set(data.keys()) - known_keys - {"schema_version"} + if unknown_keys: + logger.warning("unknown_keys=<%s> | ignoring unknown fields in checkpoint data", unknown_keys) + return cls(**{k: v for k, v in data.items() if k in known_keys}) diff --git a/src/strands/types/event_loop.py b/src/strands/types/event_loop.py index fca141327..73d4e2bc0 100644 --- a/src/strands/types/event_loop.py +++ b/src/strands/types/event_loop.py @@ -38,6 +38,7 @@ class Metrics(TypedDict, total=False): StopReason = Literal[ "cancelled", + "checkpoint", "content_filtered", "end_turn", "guardrail_intervened", @@ -49,6 +50,7 @@ class Metrics(TypedDict, total=False): """Reason for the model ending its response generation. - "cancelled": Agent execution was cancelled via agent.cancel() +- "checkpoint": Agent paused for durable checkpoint persistence - "content_filtered": Content was filtered due to policy violation - "end_turn": Normal completion of the response - "guardrail_intervened": Guardrail system intervened diff --git a/tests/strands/experimental/checkpoint/__init__.py b/tests/strands/experimental/checkpoint/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/strands/experimental/checkpoint/test_checkpoint.py b/tests/strands/experimental/checkpoint/test_checkpoint.py new file mode 100644 index 000000000..4435fb3db --- /dev/null +++ b/tests/strands/experimental/checkpoint/test_checkpoint.py @@ -0,0 +1,53 @@ +"""Tests for strands.experimental.checkpoint — Checkpoint serialization.""" + +import pytest + +from strands.experimental.checkpoint import CHECKPOINT_SCHEMA_VERSION, Checkpoint + + +class TestCheckpoint: + """Checkpoint dataclass serialization tests.""" + + def test_round_trip(self): + checkpoint = Checkpoint( + position="after_model", + cycle_index=1, + snapshot={"messages": []}, + app_data={"workflow_id": "wf-123"}, + ) + data = checkpoint.to_dict() + restored = Checkpoint.from_dict(data) + + assert restored.position == checkpoint.position + assert restored.cycle_index == checkpoint.cycle_index + assert restored.snapshot == checkpoint.snapshot + assert restored.app_data == checkpoint.app_data + assert restored.schema_version == CHECKPOINT_SCHEMA_VERSION + + def test_schema_version_immutable(self): + checkpoint = Checkpoint(position="after_tools") + assert checkpoint.schema_version == CHECKPOINT_SCHEMA_VERSION + + def test_schema_version_mismatch_raises(self): + data = Checkpoint(position="after_model").to_dict() + data["schema_version"] = "0.0" + with pytest.raises(ValueError, match="not compatible with current version"): + Checkpoint.from_dict(data) + + def test_defaults(self): + checkpoint = Checkpoint(position="after_model") + assert checkpoint.cycle_index == 0 + assert checkpoint.snapshot == {} + assert checkpoint.app_data == {} + + def test_from_dict_warns_on_unknown_fields(self, caplog): + data = Checkpoint(position="after_tools").to_dict() + data["unknown_future_field"] = "something" + restored = Checkpoint.from_dict(data) + assert restored.position == "after_tools" + assert "unknown_future_field" in caplog.text + + def test_from_dict_missing_schema_version_raises(self): + data = {"position": "after_model", "cycle_index": 0, "snapshot": {}, "app_data": {}} + with pytest.raises(ValueError, match="not compatible with current version"): + Checkpoint.from_dict(data)