-
Notifications
You must be signed in to change notification settings - Fork 800
feat(checkpoint): wire checkpointing into agent event loop #2190
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -29,6 +29,7 @@ | |
| from .._async import run_async | ||
| from ..event_loop._retry import ModelRetryStrategy | ||
| from ..event_loop.event_loop import INITIAL_DELAY, MAX_ATTEMPTS, MAX_DELAY, event_loop_cycle | ||
| from ..experimental.checkpoint import Checkpoint | ||
| from ..tools._tool_helpers import generate_missing_tool_result_content | ||
| from ..types._snapshot import ( | ||
| SNAPSHOT_SCHEMA_VERSION, | ||
|
|
@@ -146,6 +147,7 @@ def __init__( | |
| tool_executor: ToolExecutor | None = None, | ||
| retry_strategy: ModelRetryStrategy | _DefaultRetryStrategySentinel | None = _DEFAULT_RETRY_STRATEGY, | ||
| concurrent_invocation_mode: ConcurrentInvocationMode = ConcurrentInvocationMode.THROW, | ||
| checkpointing: bool = False, | ||
| ): | ||
| """Initialize the Agent with the specified configuration. | ||
|
|
||
|
|
@@ -214,6 +216,11 @@ def __init__( | |
| Set to "unsafe_reentrant" to skip lock acquisition entirely, allowing concurrent invocations. | ||
| Warning: "unsafe_reentrant" makes no guarantees about resulting behavior and is provided | ||
| only for advanced use cases where the caller understands the risks. | ||
| checkpointing: When True, the event loop pauses at cycle boundaries (after model call, | ||
| after all tools execute) and returns an AgentResult with stop_reason="checkpoint" | ||
| and a populated ``checkpoint`` field. Persist the checkpoint and resume by passing a | ||
| ``CheckpointResumeContent`` block as the next prompt. Defaults to False. | ||
| See :mod:`strands.experimental.checkpoint` for usage and limitations. | ||
|
|
||
| Raises: | ||
| ValueError: If agent id contains path separators. | ||
|
|
@@ -304,6 +311,10 @@ def __init__( | |
|
|
||
| self._interrupt_state = _InterruptState() | ||
|
|
||
| # Checkpointing: when True, event loop pauses at cycle boundaries | ||
| self._checkpointing: bool = checkpointing | ||
| self._checkpoint_resume_context: Checkpoint | None = None | ||
|
|
||
| # Runtime state for model providers (e.g., server-side response ids) | ||
| self._model_state: dict[str, Any] = {} | ||
|
|
||
|
|
@@ -998,6 +1009,50 @@ async def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages: | |
| if self._interrupt_state.activated: | ||
| return [] | ||
|
|
||
| # Resume detection — must run before existing shape handling so checkpointResume | ||
| # blocks aren't misinterpreted as content blocks. Mirrors _InterruptState.resume() | ||
| # conventions (TypeError for shape, KeyError for lookup, ValueError for misconfig) | ||
| # with one addition: a schema-version mismatch raises CheckpointException (the | ||
| # SDK-wide convention for checkpoint errors; parallel to SessionException and | ||
| # SnapshotException). Error messages use the SDK's key=<value> | message format. | ||
| if isinstance(prompt, list) and prompt: | ||
| has_checkpoint_resume = any( | ||
| isinstance(content, dict) and "checkpointResume" in content for content in prompt | ||
| ) | ||
| if has_checkpoint_resume: | ||
| if not self._checkpointing: | ||
| raise ValueError( | ||
| "Received checkpointResume block but agent was created with " | ||
| "checkpointing=False. Pass checkpointing=True when constructing " | ||
| "the Agent to enable durable execution." | ||
| ) | ||
|
|
||
| invalid_types = [ | ||
| key | ||
| for content in prompt | ||
| if isinstance(content, dict) | ||
| for key in content | ||
| if key != "checkpointResume" | ||
| ] | ||
| if invalid_types: | ||
| raise TypeError( | ||
| f"content_types=<{invalid_types}> | checkpointResume cannot be mixed with other content types" | ||
| ) | ||
|
|
||
| if len(prompt) != 1: | ||
| raise TypeError( | ||
| f"block_count=<{len(prompt)}> | only one checkpointResume block permitted per prompt" | ||
| ) | ||
|
|
||
| resume_block = prompt[0].get("checkpointResume", {}) | ||
| if not isinstance(resume_block, dict) or "checkpoint" not in resume_block: | ||
| raise KeyError("checkpoint | missing required key in checkpointResume block") | ||
|
|
||
| checkpoint = Checkpoint.from_dict(resume_block["checkpoint"]) | ||
| self.load_snapshot(Snapshot.from_dict(checkpoint.snapshot)) | ||
|
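There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Issue: The checkpoint restore on these lines (`Checkpoint.from_dict(...)` followed by `self.load_snapshot(Snapshot.from_dict(checkpoint.snapshot))`) can raise errors that do not make it clear the failure came from restoring agent state from a checkpoint. Suggestion: Consider wrapping lines 1047-1048 in a try/except that provides a clear error message, e.g., "Failed to restore agent state from checkpoint snapshot: {original_error}".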
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fair — |
||
| self._checkpoint_resume_context = checkpoint | ||
| return [] | ||
|
|
||
| messages: Messages | None = None | ||
| if prompt is not None: | ||
| # Check if the latest message is toolUse | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,6 +15,7 @@ | |
|
|
||
| from opentelemetry import trace as trace_api | ||
|
|
||
| from ..experimental.checkpoint import Checkpoint, CheckpointPosition | ||
| from ..hooks import AfterModelCallEvent, BeforeModelCallEvent, MessageAddedEvent | ||
| from ..telemetry.metrics import Trace | ||
| from ..telemetry.tracer import Tracer, get_tracer | ||
|
|
@@ -75,6 +76,33 @@ def _has_tool_use_in_latest_message(messages: "Messages") -> bool: | |
| return False | ||
|
|
||
|
|
||
| def _build_checkpoint_stop_event( | ||
| agent: "Agent", | ||
| position: "CheckpointPosition", | ||
| cycle_index: int, | ||
| message: Message, | ||
| request_state: Any, | ||
| ) -> EventLoopStopEvent: | ||
| """Build the EventLoopStopEvent that pauses the agent at a checkpoint boundary. | ||
|
|
||
| Called from the two checkpoint emission sites (after_model in ``event_loop_cycle`` | ||
| and after_tools in ``_handle_tool_execution``) to keep snapshot capture and event | ||
| construction in one place. | ||
| """ | ||
| checkpoint = Checkpoint( | ||
| position=position, | ||
| cycle_index=cycle_index, | ||
| snapshot=agent.take_snapshot(preset="session").to_dict(), | ||
| ) | ||
| return EventLoopStopEvent( | ||
| "checkpoint", | ||
| message, | ||
| agent.event_loop_metrics, | ||
| request_state, | ||
| checkpoint=checkpoint, | ||
| ) | ||
|
|
||
|
|
||
| async def event_loop_cycle( | ||
| agent: "Agent", | ||
| invocation_state: dict[str, Any], | ||
|
|
@@ -122,6 +150,21 @@ async def event_loop_cycle( | |
| # Initialize state and get cycle trace | ||
| if "request_state" not in invocation_state: | ||
| invocation_state["request_state"] = {} | ||
|
|
||
| # Consume checkpoint resume context (one-shot: cleared after reading). | ||
| # Cross-invocation state (resume context) lives on the agent; within-cycle | ||
| # transient state (resume position for the skip check, cycle index) lives | ||
| # in invocation_state. | ||
| resume_context = agent._checkpoint_resume_context | ||
| if resume_context is not None: | ||
| agent._checkpoint_resume_context = None | ||
| # after_tools completed that cycle, so next cycle starts at +1 | ||
| next_cycle = ( | ||
| resume_context.cycle_index + 1 if resume_context.position == "after_tools" else resume_context.cycle_index | ||
| ) | ||
| invocation_state["_checkpoint_cycle_index"] = next_cycle | ||
| invocation_state["_checkpoint_resume_position"] = resume_context.position | ||
|
|
||
| attributes = {"event_loop_cycle_id": str(invocation_state.get("event_loop_cycle_id"))} | ||
| cycle_start_time, cycle_trace = agent.event_loop_metrics.start_cycle(attributes=attributes) | ||
| invocation_state["event_loop_cycle_trace"] = cycle_trace | ||
|
|
@@ -181,6 +224,27 @@ async def event_loop_cycle( | |
| ) | ||
|
|
||
| if stop_reason == "tool_use": | ||
| # Checkpoint after model call, before tool execution. | ||
| # One-shot pop: safe because after_model always returns before reaching | ||
| # after_tools, so the stashed position is only consumed once. | ||
| if agent._checkpointing: | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Issue: When a cancel has been requested and `checkpointing=True`, the checkpoint return on this path takes precedence and the cancel is ignored. This is probably fine for durable-execution semantics (the orchestrator decides whether to re-cancel), but it's an unspecified interaction that could surprise users who expect the cancel to stop the agent. Suggestion: Document the cancel-vs-checkpoint precedence (checkpoint wins, cancel is ignored) either in the module docstring or add a test that specifies the expected behavior. |
||
| resume_position = invocation_state.pop("_checkpoint_resume_position", None) | ||
| if resume_position == "after_model": | ||
| pass # Just resumed here — skip re-checkpoint, proceed to tools | ||
| else: | ||
| cycle_index = invocation_state.get("_checkpoint_cycle_index", 0) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Issue: When Also: the |
||
| agent.event_loop_metrics.end_cycle(cycle_start_time, cycle_trace) | ||
| if cycle_span: | ||
| tracer.end_event_loop_cycle_span(span=cycle_span, message=message) | ||
| yield _build_checkpoint_stop_event( | ||
| agent=agent, | ||
| position="after_model", | ||
| cycle_index=cycle_index, | ||
| message=message, | ||
| request_state=invocation_state["request_state"], | ||
| ) | ||
| return | ||
|
|
||
| # Handle tool execution | ||
| tool_events = _handle_tool_execution( | ||
| stop_reason, | ||
|
|
@@ -590,6 +654,22 @@ async def _handle_tool_execution( | |
| ) | ||
| return | ||
|
|
||
| # Checkpoint after all tools complete, before the next model call. | ||
| # Note: checkpoints are only emitted on tool_use cycles. An agent whose model | ||
| # returns end_turn on the very first call completes normally with no checkpoint | ||
| # (there is nothing durable to resume from). | ||
| if agent._checkpointing: | ||
| cycle_index = invocation_state.get("_checkpoint_cycle_index", 0) | ||
| invocation_state["_checkpoint_cycle_index"] = cycle_index + 1 | ||
| yield _build_checkpoint_stop_event( | ||
| agent=agent, | ||
| position="after_tools", | ||
| cycle_index=cycle_index, | ||
| message=message, | ||
| request_state=invocation_state["request_state"], | ||
| ) | ||
| return | ||
|
|
||
| events = recurse_event_loop( | ||
| agent=agent, invocation_state=invocation_state, structured_output_context=structured_output_context | ||
| ) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,6 +17,14 @@ | |
| - State via AgentResult.checkpoint field | ||
| - Resume via checkpointResume content block in next agent() call | ||
|
|
||
| Interaction with interrupts: | ||
| - Interrupts take priority over checkpoints. If a tool raises an Interrupt | ||
| during a checkpointing=True cycle, the event loop returns | ||
| stop_reason="interrupt" (not "checkpoint"). The after_tools checkpoint | ||
| is never reached because the interrupt path returns early. | ||
| - This is intentional: interrupts require human input, checkpoints are | ||
| for worker-level durability. Different semantics, different priorities. | ||
|
|
||
| V0 Known Limitations: | ||
| - Metrics reset on each resume call. The caller is responsible for aggregating | ||
| metrics across a durable run. EventLoopMetrics reflects only the current call. | ||
|
|
@@ -35,17 +43,23 @@ | |
| from dataclasses import asdict, dataclass, field | ||
| from typing import Any, Literal | ||
|
|
||
| from ...types.exceptions import CheckpointException | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
| CHECKPOINT_SCHEMA_VERSION = "1.0" | ||
|
|
||
| CheckpointPosition = Literal["after_model", "after_tools"] | ||
|
|
||
|
|
||
| @dataclass | ||
| @dataclass(frozen=True) | ||
| class Checkpoint: | ||
| """Pause point in the agent loop. Treat as opaque — pass back to resume. | ||
|
|
||
| Immutable by design: a checkpoint represents a captured moment. Mutating | ||
| one after creation would decouple it from the snapshot it was built with, | ||
| which is always a bug. Build a new Checkpoint if you need different values. | ||
|
|
||
| Attributes: | ||
| position: What just completed (after_model or after_tools). | ||
| cycle_index: Which ReAct loop cycle (0-based). | ||
|
|
@@ -79,11 +93,11 @@ def from_dict(cls, data: dict[str, Any]) -> "Checkpoint": | |
| data: Serialized checkpoint data. | ||
|
|
||
| Raises: | ||
| ValueError: If schema_version doesn't match the current version. | ||
| CheckpointException: If schema_version doesn't match the current version. | ||
| """ | ||
| version = data.get("schema_version", "") | ||
| if version != CHECKPOINT_SCHEMA_VERSION: | ||
| raise ValueError( | ||
| raise CheckpointException( | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Issue: Suggestion: Since this is still experimental, the change is reasonable — but ensure the exception hierarchy is intentional. Consider whether |
||
| f"Checkpoints with schema version {version!r} are not compatible " | ||
| f"with current version {CHECKPOINT_SCHEMA_VERSION}." | ||
| ) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,37 @@ | ||
| """Content-block types for checkpoint resume. | ||
|
|
||
| Mirrors the interrupt pattern (`InterruptResponseContent` in `types/interrupt.py`). | ||
| Stays under `experimental/checkpoint/` for V0; will graduate to | ||
| `src/strands/types/checkpoint.py` when the feature exits experimental. | ||
| """ | ||
|
|
||
| from typing import Any, TypedDict | ||
|
|
||
|
|
||
| class CheckpointResumeDict(TypedDict): | ||
| """Inner payload for a checkpointResume content block. | ||
|
|
||
| Attributes: | ||
| checkpoint: Serialized Checkpoint as produced by ``Checkpoint.to_dict()``. | ||
| """ | ||
|
|
||
| checkpoint: dict[str, Any] | ||
|
|
||
|
|
||
| class CheckpointResumeContent(TypedDict): | ||
| """Content block that resumes a paused durable agent. | ||
|
|
||
| Pass a list containing exactly one instance of this type as the prompt to | ||
| ``Agent.invoke_async`` / ``Agent.__call__`` to resume from a checkpoint. | ||
|
|
||
| Example:: | ||
|
|
||
| result = await agent.invoke_async( | ||
| [{"checkpointResume": {"checkpoint": previous_checkpoint.to_dict()}}] | ||
| ) | ||
|
|
||
| Attributes: | ||
| checkpointResume: The resume payload carrying the serialized checkpoint. | ||
| """ | ||
|
|
||
| checkpointResume: CheckpointResumeDict |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Issue: This introduces a new public API surface: a `checkpointing` parameter on `Agent.__init__`, a new `AgentResult` field, a new `StopReason` value, and new content block types. Per the API Bar Raising process, this falls under "moderate changes" (adding a new class/abstraction customers use to achieve new behavior) and should have the `needs-api-review` label.
Suggestion: Add the `needs-api-review` label to this PR and ensure an API reviewer evaluates the public API design before merge. Key API decisions to review:
- Is `checkpointing: bool` the right granularity, or should it be an enum/config object for future extensibility (e.g., checkpoint only at `after_tools`, custom positions)?
- Is `checkpointResume` as a content block the right abstraction for resume, vs. a dedicated method like `agent.resume_from_checkpoint(checkpoint)`?
- Should `Checkpoint` be a frozen dataclass to prevent accidental mutation?