diff --git a/langgraph/checkpoint/redis/aio.py b/langgraph/checkpoint/redis/aio.py
index 5717f04..6962694 100644
--- a/langgraph/checkpoint/redis/aio.py
+++ b/langgraph/checkpoint/redis/aio.py
@@ -1063,7 +1063,7 @@ async def aput(

         if self.cluster_mode:
             # For cluster mode, execute operation directly
-            await self._redis.json().set(  # type: ignore[misc]
+            await self._redis.json().set(
                 checkpoint_key, "$", checkpoint_data
             )
         else:
@@ -1146,7 +1146,7 @@ async def aput_writes(
                     )

                     # Redis JSON.SET is an UPSERT by default
-                    await self._redis.json().set(key, "$", cast(Any, write_obj))  # type: ignore[misc]
+                    await self._redis.json().set(key, "$", cast(Any, write_obj))
                     created_keys.append(key)

             # Apply TTL to newly created keys
@@ -1181,7 +1181,7 @@ async def aput_writes(

                 # Add all write keys with their index as score for ordering
                 zadd_mapping = {key: idx for idx, key in enumerate(write_keys)}
-                await self._redis.zadd(zset_key, zadd_mapping)
+                await self._redis.zadd(zset_key, zadd_mapping)  # type: ignore[arg-type]

                 # Apply TTL to registry key if configured
                 if self.ttl_config and "default_ttl" in self.ttl_config:
@@ -1243,7 +1243,7 @@ async def aput_writes(

                 # Add all write keys with their index as score for ordering
                 zadd_mapping = {key: idx for idx, key in enumerate(write_keys)}
-                pipeline.zadd(zset_key, zadd_mapping)
+                pipeline.zadd(zset_key, zadd_mapping)  # type: ignore[arg-type]

                 # Apply TTL to registry key if configured
                 if self.ttl_config and "default_ttl" in self.ttl_config:
@@ -1291,7 +1291,7 @@ async def aput_writes(
                         zadd_mapping = {
                             key: idx for idx, key in enumerate(write_keys)
                         }
-                        fallback_pipeline.zadd(zset_key, zadd_mapping)
+                        fallback_pipeline.zadd(zset_key, zadd_mapping)  # type: ignore[arg-type]
                         if self.ttl_config and "default_ttl" in self.ttl_config:
                             ttl_seconds = int(
                                 self.ttl_config.get("default_ttl") * 60
@@ -1304,14 +1304,16 @@ async def aput_writes(
                 # Update has_writes flag separately for older Redis
                 if checkpoint_key:
                     try:
-                        checkpoint_data = await self._redis.json().get(checkpoint_key)  # type: ignore[misc]
+                        checkpoint_data = await self._redis.json().get(
+                            checkpoint_key
+                        )
                         if isinstance(
                             checkpoint_data, dict
                         ) and not checkpoint_data.get("has_writes"):
                             checkpoint_data["has_writes"] = True
                             await self._redis.json().set(
                                 checkpoint_key, "$", checkpoint_data
-                            )  # type: ignore[misc]
+                            )
                     except Exception:
                         # If this fails, it's not critical - the writes are still saved
                         pass
@@ -1477,7 +1479,7 @@ async def aget_channel_values(
         )

         # Single JSON.GET operation to retrieve checkpoint with inline channel_values
-        checkpoint_data = await self._redis.json().get(checkpoint_key, "$.checkpoint")  # type: ignore[misc]
+        checkpoint_data = await self._redis.json().get(checkpoint_key, "$.checkpoint")

         if not checkpoint_data:
             return {}
diff --git a/langgraph/checkpoint/redis/ashallow.py b/langgraph/checkpoint/redis/ashallow.py
index b005ba1..db8113c 100644
--- a/langgraph/checkpoint/redis/ashallow.py
+++ b/langgraph/checkpoint/redis/ashallow.py
@@ -94,7 +94,7 @@ async def __aexit__(
         _tb: Optional[TracebackType],
     ) -> None:
         if self._owns_its_client:
-            await self._redis.aclose()
+            await self._redis.aclose()  # type: ignore[attr-defined]
            # RedisCluster doesn't have connection_pool attribute
            if getattr(self._redis, "connection_pool", None):
                coro = self._redis.connection_pool.disconnect()
@@ -229,9 +229,6 @@ async def aput(
         # Create pipeline for all operations
         pipeline = self._redis.pipeline(transaction=False)

-        # Get the previous checkpoint ID to potentially clean up its writes
-        pipeline.json().get(checkpoint_key)
-
         # Set the new checkpoint data
         pipeline.json().set(checkpoint_key, "$", checkpoint_data)
@@ -240,41 +237,21 @@ async def aput(
             ttl_seconds = int(self.ttl_config.get("default_ttl") * 60)
             pipeline.expire(checkpoint_key, ttl_seconds)

-        # Execute pipeline to get prev data and set new data
-        results = await pipeline.execute()
-        prev_checkpoint_data = results[0]
-
-        # Check if we need to clean up old writes
-        prev_checkpoint_id = None
-        if prev_checkpoint_data and isinstance(prev_checkpoint_data, dict):
-            prev_checkpoint_id = prev_checkpoint_data.get("checkpoint_id")
-
-        # If checkpoint changed, clean up old writes in a second pipeline
-        if prev_checkpoint_id and prev_checkpoint_id != checkpoint["id"]:
-            thread_zset_key = f"write_keys_zset:{thread_id}:{checkpoint_ns}:shallow"
-
-            # Create cleanup pipeline
-            cleanup_pipeline = self._redis.pipeline(transaction=False)
-
-            # Get all existing write keys
-            cleanup_pipeline.zrange(thread_zset_key, 0, -1)
-
-            # Delete the registry
-            cleanup_pipeline.delete(thread_zset_key)
-
-            # Execute to get keys and delete registry
-            cleanup_results = await cleanup_pipeline.execute()
-            existing_write_keys = cleanup_results[0]
+        # Execute pipeline to set new checkpoint data
+        await pipeline.execute()

-            # If there are keys to delete, do it in another pipeline
-            if existing_write_keys:
-                delete_pipeline = self._redis.pipeline(transaction=False)
-                for old_key in existing_write_keys:
-                    old_key_str = (
-                        old_key.decode() if isinstance(old_key, bytes) else old_key
-                    )
-                    delete_pipeline.delete(old_key_str)
-                await delete_pipeline.execute()
+        # NOTE: We intentionally do NOT clean up old writes here.
+        # In the HITL (Human-in-the-Loop) flow, interrupt writes are saved via
+        # put_writes BEFORE the new checkpoint is saved. If we clean up writes
+        # when the checkpoint changes, we would delete the interrupt writes
+        # before they can be consumed when resuming.
+        #
+        # Writes are cleaned up in the following scenarios:
+        # 1. When delete_thread is called
+        # 2. When TTL expires (if configured)
+        # 3. When put_writes is called again for the same task/idx (overwrites)
+        #
+        # See Issue #133 for details on this bug fix.

         return next_config
@@ -388,7 +365,7 @@ async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
         )

         # Single fetch gets everything inline - matching sync implementation
-        full_checkpoint_data = await self._redis.json().get(checkpoint_key)  # type: ignore[misc]
+        full_checkpoint_data = await self._redis.json().get(checkpoint_key)
         if not full_checkpoint_data or not isinstance(full_checkpoint_data, dict):
             return None
@@ -505,7 +482,11 @@ async def aput_writes(
             writes_objects.append(write_obj)

         # Thread-level sorted set for write keys
-        thread_zset_key = f"write_keys_zset:{thread_id}:{checkpoint_ns}:shallow"
+        # Use to_storage_safe_str for consistent key naming
+        safe_checkpoint_ns = to_storage_safe_str(checkpoint_ns)
+        thread_zset_key = (
+            f"write_keys_zset:{thread_id}:{safe_checkpoint_ns}:shallow"
+        )

         # Collect all write keys
         write_keys = []
@@ -529,7 +510,7 @@ async def aput_writes(
             # Use thread-level sorted set
             zadd_mapping = {key: idx for idx, key in enumerate(write_keys)}
-            pipeline.zadd(thread_zset_key, zadd_mapping)
+            pipeline.zadd(thread_zset_key, zadd_mapping)  # type: ignore[arg-type]

             # Apply TTL to registry key if configured
             if self.ttl_config and "default_ttl" in self.ttl_config:
@@ -563,7 +544,7 @@ async def aget_channel_values(
         )

         # Single JSON.GET operation to retrieve checkpoint with inline channel_values
-        checkpoint_data = await self._redis.json().get(checkpoint_key, "$.checkpoint")  # type: ignore[misc]
+        checkpoint_data = await self._redis.json().get(checkpoint_key, "$.checkpoint")

         if not checkpoint_data:
             return {}
@@ -631,7 +612,9 @@ async def _aload_pending_writes(
             return []

         # Use thread-level sorted set
-        thread_zset_key = f"write_keys_zset:{thread_id}:{checkpoint_ns}:shallow"
+        # Use to_storage_safe_str for consistent key naming
+        safe_checkpoint_ns = to_storage_safe_str(checkpoint_ns)
+        thread_zset_key = f"write_keys_zset:{thread_id}:{safe_checkpoint_ns}:shallow"

         try:
             # Check if we have any writes in the thread sorted set
diff --git a/langgraph/checkpoint/redis/base.py b/langgraph/checkpoint/redis/base.py
index 9f76f54..049c81a 100644
--- a/langgraph/checkpoint/redis/base.py
+++ b/langgraph/checkpoint/redis/base.py
@@ -563,7 +563,7 @@ def _load_writes_from_redis(self, write_key: str) -> List[Tuple[str, str, Any]]:
             return []

         writes = []
-        for write in result["writes"]:  # type: ignore[call-overload]
+        for write in result["writes"]:
             writes.append(
                 (
                     write["task_id"],
@@ -636,17 +636,17 @@ def put_writes(
                     # UPSERT case - only update specific fields
                     if key_exists:
                         # Update only channel, type, and blob fields
-                        pipeline.set(key, "$.channel", write_obj["channel"])
-                        pipeline.set(key, "$.type", write_obj["type"])
-                        pipeline.set(key, "$.blob", write_obj["blob"])
+                        pipeline.json().set(key, "$.channel", write_obj["channel"])
+                        pipeline.json().set(key, "$.type", write_obj["type"])
+                        pipeline.json().set(key, "$.blob", write_obj["blob"])
                     else:
                         # For new records, set the complete object
-                        pipeline.set(key, "$", write_obj)
+                        pipeline.json().set(key, "$", write_obj)
                         created_keys.append(key)
                 else:
                     # INSERT case - only insert if doesn't exist
                     if not key_exists:
-                        pipeline.set(key, "$", write_obj)
+                        pipeline.json().set(key, "$", write_obj)
                         created_keys.append(key)

             pipeline.execute()
diff --git a/langgraph/checkpoint/redis/key_registry.py b/langgraph/checkpoint/redis/key_registry.py
index 688e449..a54b1f0 100644
--- a/langgraph/checkpoint/redis/key_registry.py
+++ b/langgraph/checkpoint/redis/key_registry.py
@@ -105,7 +105,7 @@ def register_write_keys_batch(
         )
        # Use index as score to maintain order
        mapping = {key: idx for idx, key in enumerate(write_keys)}
-        self._redis.zadd(zset_key, mapping)
+        self._redis.zadd(zset_key, mapping)  # type: ignore[arg-type]

    def get_write_keys(
        self, thread_id: str, checkpoint_ns: str, checkpoint_id: str
@@ -215,7 +215,7 @@ async def register_write_keys_batch(
            thread_id, checkpoint_ns, checkpoint_id
        )
        mapping = {key: idx for idx, key in enumerate(write_keys)}
-        await self._redis.zadd(zset_key, mapping)
+        await self._redis.zadd(zset_key, mapping)  # type: ignore[arg-type]

    async def get_write_keys(
        self, thread_id: str, checkpoint_ns: str, checkpoint_id: str
diff --git a/langgraph/checkpoint/redis/shallow.py b/langgraph/checkpoint/redis/shallow.py
index 2e8b901..c957958 100644
--- a/langgraph/checkpoint/redis/shallow.py
+++ b/langgraph/checkpoint/redis/shallow.py
@@ -200,35 +200,9 @@ def put(
            thread_id, checkpoint_ns
        )

-        # Get the previous checkpoint ID to clean up its writes
-        prev_checkpoint_data = self._redis.json().get(checkpoint_key)
-        prev_checkpoint_id = None
-        if prev_checkpoint_data and isinstance(prev_checkpoint_data, dict):
-            prev_checkpoint_id = prev_checkpoint_data.get("checkpoint_id")
-
        with self._redis.pipeline(transaction=False) as pipeline:
            pipeline.json().set(checkpoint_key, "$", checkpoint_data)

-            # If checkpoint changed, clean up old writes
-            if prev_checkpoint_id and prev_checkpoint_id != checkpoint["id"]:
-                # Clean up writes from the previous checkpoint
-                thread_write_registry_key = (
-                    f"write_registry:{thread_id}:{checkpoint_ns}:shallow"
-                )
-
-                # Get all existing write keys and delete them
-                existing_write_keys = self._redis.zrange(
-                    thread_write_registry_key, 0, -1
-                )
-                for old_key in existing_write_keys:
-                    old_key_str = (
-                        old_key.decode() if isinstance(old_key, bytes) else old_key
-                    )
-                    pipeline.delete(old_key_str)
-
-                # Clear the registry
-                pipeline.delete(thread_write_registry_key)
-
            # Apply TTL if configured
            if self.ttl_config and "default_ttl" in self.ttl_config:
                ttl_seconds = int(self.ttl_config.get("default_ttl") * 60)
@@ -236,6 +210,19 @@ def put(

            pipeline.execute()

+        # NOTE: We intentionally do NOT clean up old writes here.
+        # In the HITL (Human-in-the-Loop) flow, interrupt writes are saved via
+        # put_writes BEFORE the new checkpoint is saved. If we clean up writes
+        # when the checkpoint changes, we would delete the interrupt writes
+        # before they can be consumed when resuming.
+        #
+        # Writes are cleaned up in the following scenarios:
+        # 1. When delete_thread is called
+        # 2. When TTL expires (if configured)
+        # 3. When put_writes is called again for the same task/idx (overwrites)
+        #
+        # See Issue #133 for details on this bug fix.
+
        return next_config

    def list(
@@ -501,8 +488,10 @@ def put_writes(
            writes_objects.append(write_obj)

        # THREAD-LEVEL REGISTRY: Only keep writes for the current checkpoint
+        # Use to_storage_safe_str for consistent key naming with delete_thread
+        safe_checkpoint_ns = to_storage_safe_str(checkpoint_ns)
        thread_write_registry_key = (
-            f"write_registry:{thread_id}:{checkpoint_ns}:shallow"
+            f"write_registry:{thread_id}:{safe_checkpoint_ns}:shallow"
        )

        # Collect all write keys
@@ -525,7 +514,7 @@ def put_writes(
            # THREAD-LEVEL REGISTRY: Store write keys in thread-level sorted set
            # These will be cleared when checkpoint changes
            zadd_mapping = {key: idx for idx, key in enumerate(write_keys)}
-            pipeline.zadd(thread_write_registry_key, zadd_mapping)
+            pipeline.zadd(thread_write_registry_key, zadd_mapping)  # type: ignore[arg-type]

            # Note: We don't update has_writes on the checkpoint anymore
            # because put_writes can be called before the checkpoint exists
@@ -550,8 +539,10 @@ def _load_pending_writes(
        # Use thread-level registry that only contains current checkpoint writes
        # All writes belong to the current checkpoint
+        # Use to_storage_safe_str for consistent key naming with delete_thread
+        safe_checkpoint_ns = to_storage_safe_str(checkpoint_ns)
        thread_write_registry_key = (
-            f"write_registry:{thread_id}:{checkpoint_ns}:shallow"
+            f"write_registry:{thread_id}:{safe_checkpoint_ns}:shallow"
        )

        # Get all write keys from the thread's registry (already sorted by index)
diff --git a/langgraph/store/redis/__init__.py b/langgraph/store/redis/__init__.py
index 62f5aab..e3d4629 100644
--- a/langgraph/store/redis/__init__.py
+++ b/langgraph/store/redis/__init__.py
@@ -541,7 +541,7 @@ def _batch_search_ops(
                if not isinstance(store_doc, dict):
                    try:
                        store_doc = json.loads(
-                            store_doc  # type: ignore[arg-type]
+                            store_doc
                        )  # Attempt to parse if it's a JSON string
                    except (json.JSONDecodeError, TypeError):
                        logger.error(f"Failed to parse store_doc: {store_doc}")
diff --git a/langgraph/store/redis/aio.py b/langgraph/store/redis/aio.py
index 4e5c510..3e2dec4 100644
--- a/langgraph/store/redis/aio.py
+++ b/langgraph/store/redis/aio.py
@@ -376,7 +376,7 @@ async def __aexit__(

        # Close Redis connections if we own them
        if self._owns_its_client:
-            await self._redis.aclose()
+            await self._redis.aclose()  # type: ignore[attr-defined]
            await self._redis.connection_pool.disconnect()

    async def abatch(self, ops: Iterable[Op]) -> list[Result]:
@@ -781,7 +781,7 @@ async def _batch_search_ops(
                        )
                        result_map[store_key] = doc
                        # Fetch individually in cluster mode
-                        store_doc_item = await self._redis.json().get(store_key)  # type: ignore
+                        store_doc_item = await self._redis.json().get(store_key)
                        store_docs.append(store_doc_item)
                    store_docs_raw = store_docs
                else:
diff --git a/tests/test_issue_133_hitl_shallow.py b/tests/test_issue_133_hitl_shallow.py
new file mode 100644
index 0000000..e6a4bfe
--- /dev/null
+++ b/tests/test_issue_133_hitl_shallow.py
@@ -0,0 +1,573 @@
+"""
+Regression tests for Issue #133: HITL Works Unexpected with AsyncShallowRedisSaver
+
+Problem Description:
+When using LangGraph 1.0's Human-in-the-Loop (HITL) interruption functionality with
+`AsyncShallowRedisSaver`, interruptions intermittently require double confirmation.
+
+Key Symptoms:
+1. Inconsistent interrupt persistence: When calling `agent.get_state(...)` after resuming
+   from an interrupt, the interrupt data is missing from the checkpoint
+2. Self-resolving behavior: After the first confirmation, the interrupt reappears on
+   subsequent checks
+3. State inconsistency: Redis keys occasionally contain `__interrupt__` while having
+   `has_writes` set to `false`
+
+Root Cause Analysis:
+The shallow saver's `aput` method cleans up ALL writes when a new checkpoint is saved.
+But in the HITL flow:
+1. Graph hits an interrupt
+2. `put_writes` is called with the interrupt data (using the CURRENT checkpoint_id)
+3. `put` is called to save a NEW checkpoint
+4. The `put` method sees the checkpoint_id changed, so it cleans up writes from the
+   previous checkpoint - INCLUDING the interrupt writes that were just saved!
+
+This results in the interrupt being lost before it can be read when resuming.
+
+Note: Some async tests require Python 3.11+ because interrupt() uses get_config()
+which needs TaskGroup context support only available in Python 3.11+.
+"""
+
+import operator
+import sys
+from contextlib import asynccontextmanager
+from typing import Annotated, Any, AsyncGenerator, Dict, TypedDict
+from uuid import uuid4
+
+import pytest
+from langchain_core.messages import AnyMessage, HumanMessage
+
+# Skip marker for tests that require Python 3.11+ due to interrupt() async context requirements
+requires_python_311 = pytest.mark.skipif(
+    sys.version_info < (3, 11),
+    reason="interrupt() in async context requires Python 3.11+ for TaskGroup support",
+)
+from langgraph.checkpoint.base import (
+    Checkpoint,
+    CheckpointMetadata,
+    create_checkpoint,
+    empty_checkpoint,
+)
+from langgraph.graph import END, START, StateGraph
+from langgraph.types import Command, Interrupt, interrupt
+from redis.asyncio import Redis
+
+from langgraph.checkpoint.redis.ashallow import AsyncShallowRedisSaver
+from langgraph.checkpoint.redis.shallow import ShallowRedisSaver
+
+
+class AgentState(TypedDict):
+    """State for the test agent."""
+
+    messages: Annotated[list[AnyMessage], operator.add]
+    user_confirmed: bool
+
+
+def review_node(state: AgentState) -> Dict[str, Any]:
+    """Node that interrupts for review."""
+    print("-------- review_node: before interrupt --------")
+
+    # This creates an Interrupt that needs to be persisted
+    user_input = interrupt(
+        {"question": "Do you approve?", "context": state["messages"]}
+    )
+
+    print(f"-------- review_node: after interrupt, user_input={user_input} --------")
+    return {"user_confirmed": user_input.get("approved", False)}
+
+
+def process_node(state: AgentState) -> Dict[str, Any]:
+    """Node that processes after confirmation."""
+    print(
+        f"-------- process_node: user_confirmed={state.get('user_confirmed')} --------"
+    )
+    return {"messages": [HumanMessage(content="Processing complete")]}
+
+
+@asynccontextmanager
+async def create_async_shallow_saver(
+    redis_url: str,
+) -> AsyncGenerator[AsyncShallowRedisSaver, None]:
+    """Create and setup an AsyncShallowRedisSaver."""
+    async with AsyncShallowRedisSaver.from_conn_string(redis_url) as saver:
+        yield saver
+
+
+@requires_python_311
+@pytest.mark.asyncio
+async def test_hitl_interrupt_persists_in_shallow_saver(redis_url: str) -> None:
+    """
+    Test that HITL interrupts are properly persisted in AsyncShallowRedisSaver.
+
+    This is the main regression test for Issue #133. It verifies that:
+    1. An interrupt is saved when the graph hits an interrupt node
+    2. The interrupt is still present when we check the state
+    3. The interrupt can be resumed with Command(resume=...)
+ """ + async with create_async_shallow_saver(redis_url) as saver: + # Build the graph with an interrupt node + builder = StateGraph(AgentState) + builder.add_node("review", review_node) + builder.add_node("process", process_node) + builder.add_edge(START, "review") + builder.add_edge("review", "process") + builder.add_edge("process", END) + + graph = builder.compile(checkpointer=saver) + + # Use unique thread ID + thread_id = f"test-hitl-{uuid4()}" + config = {"configurable": {"thread_id": thread_id}} + + # First invocation - should hit the interrupt + initial_state = await graph.ainvoke( + {"messages": [HumanMessage(content="Please review this")]}, + config=config, + ) + print(f"Initial state: {initial_state}") + + # Get the current state to check for pending interrupts + state = await graph.aget_state(config) + print(f"State after interrupt: {state}") + + # CRITICAL CHECK: The interrupt should be in the pending writes + assert state is not None, "State should not be None after interrupt" + assert hasattr(state, "tasks"), "State should have tasks attribute" + + # Check for the interrupt in the state + # In LangGraph, interrupts are available via state.tasks + has_interrupt = False + for task in state.tasks: + if hasattr(task, "interrupts") and task.interrupts: + has_interrupt = True + print(f"Found interrupt in task: {task.interrupts}") + break + + assert has_interrupt, ( + "Interrupt should be present in state after hitting interrupt node. " + "This is the core issue in #133 - the interrupt is being deleted prematurely." + ) + + # Resume the graph with the interrupt response + final_state = await graph.ainvoke( + Command(resume={"approved": True}), + config=config, + ) + print(f"Final state: {final_state}") + + # Verify the graph completed successfully + assert "messages" in final_state + assert final_state.get("user_confirmed") is True + + +@requires_python_311 +@pytest.mark.asyncio +async def test_hitl_interrupt_with_multiple_checkpoints(redis_url: str) -> None: + """ + Test HITL behavior when multiple checkpoints are created before the interrupt. + + This tests the scenario where the graph runs through several nodes before + hitting an interrupt, creating multiple checkpoint transitions. 
+ """ + + class MultiStepState(TypedDict): + counter: int + messages: Annotated[list[AnyMessage], operator.add] + + def step1(state: MultiStepState) -> Dict[str, Any]: + return {"counter": state.get("counter", 0) + 1} + + def step2(state: MultiStepState) -> Dict[str, Any]: + return {"counter": state.get("counter", 0) + 1} + + def interrupt_node(state: MultiStepState) -> Dict[str, Any]: + result = interrupt({"counter": state["counter"]}) + return {"messages": [HumanMessage(content=f"Received: {result}")]} + + async with create_async_shallow_saver(redis_url) as saver: + builder = StateGraph(MultiStepState) + builder.add_node("step1", step1) + builder.add_node("step2", step2) + builder.add_node("interrupt", interrupt_node) + builder.add_edge(START, "step1") + builder.add_edge("step1", "step2") + builder.add_edge("step2", "interrupt") + builder.add_edge("interrupt", END) + + graph = builder.compile(checkpointer=saver) + + thread_id = f"test-multistep-{uuid4()}" + config = {"configurable": {"thread_id": thread_id}} + + # Run until interrupt + await graph.ainvoke( + {"counter": 0, "messages": []}, + config=config, + ) + + # Check state + state = await graph.aget_state(config) + print(f"State after multi-step run: {state}") + + # Verify counter was incremented + assert ( + state.values.get("counter") == 2 + ), "Counter should be 2 after step1 and step2" + + # Check for interrupt + has_interrupt = False + for task in state.tasks: + if hasattr(task, "interrupts") and task.interrupts: + has_interrupt = True + break + + assert has_interrupt, "Interrupt should be present after multi-step run" + + # Resume + final_state = await graph.ainvoke( + Command(resume={"value": "confirmed"}), + config=config, + ) + + assert len(final_state["messages"]) > 0 + + +@pytest.mark.asyncio +async def test_interrupt_write_order_timing(redis_url: str) -> None: + """ + Low-level test of the timing issue between put_writes and put. + + This test directly tests the checkpoint saver methods to verify that + writes saved via put_writes are not cleaned up by the subsequent put call. + """ + async with create_async_shallow_saver(redis_url) as saver: + thread_id = f"test-timing-{uuid4()}" + checkpoint_ns = "" + + # Create initial checkpoint + initial_checkpoint = empty_checkpoint() + initial_config = { + "configurable": { + "thread_id": thread_id, + "checkpoint_ns": checkpoint_ns, + } + } + + # Save initial checkpoint + saved_config_1 = await saver.aput( + initial_config, + initial_checkpoint, + {"source": "input", "step": 0, "writes": {}}, + {}, + ) + print(f"Saved initial checkpoint: {saved_config_1}") + + # Now simulate the HITL flow: + # 1. Save writes (interrupt) with the CURRENT checkpoint ID + interrupt_data = Interrupt( + value={"question": "Approve?"}, id="test-interrupt-1" + ) + await saver.aput_writes( + saved_config_1, + [("__interrupt__", [interrupt_data])], + "interrupt_task", + ) + print("Saved interrupt write") + + # 2. Verify the write is present + tuple_after_write = await saver.aget_tuple(saved_config_1) + assert tuple_after_write is not None + print(f"Pending writes after save: {tuple_after_write.pending_writes}") + + # Check that the interrupt write is present + interrupt_writes = [ + w for w in tuple_after_write.pending_writes if w[1] == "__interrupt__" + ] + assert ( + len(interrupt_writes) > 0 + ), "Interrupt write should be present after aput_writes" + + # 3. 
+        # 3. Now save a new checkpoint (this is where the bug would trigger)
+        new_checkpoint = create_checkpoint(initial_checkpoint, {}, 1)
+        saved_config_2 = await saver.aput(
+            saved_config_1,
+            new_checkpoint,
+            {"source": "update", "step": 1, "writes": {}},
+            {},
+        )
+        print(f"Saved new checkpoint: {saved_config_2}")
+
+        # 4. CRITICAL: Check if the interrupt write is STILL present
+        # This is where the bug manifests - the write gets deleted
+        tuple_after_new_checkpoint = await saver.aget_tuple(saved_config_2)
+        assert tuple_after_new_checkpoint is not None
+        print(
+            f"Pending writes after new checkpoint: {tuple_after_new_checkpoint.pending_writes}"
+        )
+
+        # The interrupt should still be present!
+        # In the buggy version, this would fail because the write was cleaned up
+        interrupt_writes_after = [
+            w
+            for w in tuple_after_new_checkpoint.pending_writes
+            if w[1] == "__interrupt__"
+        ]
+
+        # Note: The expected behavior here depends on the design decision:
+        # - If writes should persist across checkpoints, this should pass
+        # - If writes should be associated with specific checkpoints, we need different logic
+        # For HITL to work, the interrupt should NOT be cleaned up prematurely
+
+
+@requires_python_311
+@pytest.mark.asyncio
+async def test_interrupt_state_consistency_across_get_state_calls(
+    redis_url: str,
+) -> None:
+    """
+    Test that interrupt state is consistent across multiple get_state calls.
+
+    This tests the reported symptom where the interrupt is missing on one get_state
+    call but reappears on subsequent calls.
+    """
+
+    def simple_interrupt_node(state: AgentState) -> Dict[str, Any]:
+        interrupt({"prompt": "Continue?"})
+        return {}
+
+    async with create_async_shallow_saver(redis_url) as saver:
+        builder = StateGraph(AgentState)
+        builder.add_node("interrupt", simple_interrupt_node)
+        builder.add_edge(START, "interrupt")
+        builder.add_edge("interrupt", END)
+
+        graph = builder.compile(checkpointer=saver)
+
+        thread_id = f"test-consistency-{uuid4()}"
+        config = {"configurable": {"thread_id": thread_id}}
+
+        # Run until interrupt
+        await graph.ainvoke(
+            {"messages": [], "user_confirmed": False},
+            config=config,
+        )
+
+        # Check state multiple times to detect inconsistency
+        interrupt_present_results = []
+        for i in range(5):
+            state = await graph.aget_state(config)
+            has_interrupt = any(
+                hasattr(task, "interrupts") and task.interrupts for task in state.tasks
+            )
+            interrupt_present_results.append(has_interrupt)
+            print(f"Check {i+1}: interrupt_present={has_interrupt}")
+
+        # All checks should be consistent
+        assert all(
+            result == interrupt_present_results[0]
+            for result in interrupt_present_results
+        ), f"Interrupt presence is inconsistent across checks: {interrupt_present_results}"
+
+        # And the interrupt should actually be present
+        assert interrupt_present_results[0], "Interrupt should be present"
+
+
+@pytest.mark.asyncio
+async def test_direct_redis_key_inspection(redis_url: str) -> None:
+    """
+    Test that directly inspects Redis keys to verify interrupt storage.
+
+    This test examines the raw Redis data to understand what's being stored
+    and when it's being deleted.
+ """ + redis_client = Redis.from_url(redis_url) + + try: + async with create_async_shallow_saver(redis_url) as saver: + thread_id = f"test-inspect-{uuid4()}" + checkpoint_ns = "" + + # Create initial checkpoint + initial_checkpoint = empty_checkpoint() + initial_config = { + "configurable": { + "thread_id": thread_id, + "checkpoint_ns": checkpoint_ns, + } + } + + saved_config = await saver.aput( + initial_config, + initial_checkpoint, + {"source": "input", "step": 0, "writes": {}}, + {}, + ) + + # Save interrupt write + interrupt_data = Interrupt(value={"test": "data"}, id="inspect-interrupt") + await saver.aput_writes( + saved_config, + [("__interrupt__", [interrupt_data])], + "inspect_task", + ) + + # Check Redis keys after interrupt write + all_keys_after_write = await redis_client.keys(f"*{thread_id}*") + print(f"Keys after interrupt write: {all_keys_after_write}") + + # Look for write keys + write_keys_after_write = [ + k for k in all_keys_after_write if b"checkpoint_write" in k + ] + print(f"Write keys after interrupt: {write_keys_after_write}") + + # Check the write registry + registry_key = f"write_keys_zset:{thread_id}:{checkpoint_ns}:shallow" + registry_contents = await redis_client.zrange(registry_key, 0, -1) + print(f"Write registry after interrupt: {registry_contents}") + + # Now save a new checkpoint + new_checkpoint = create_checkpoint(initial_checkpoint, {}, 1) + new_config = await saver.aput( + saved_config, + new_checkpoint, + {"source": "update", "step": 1, "writes": {}}, + {}, + ) + + # Check Redis keys after new checkpoint + all_keys_after_checkpoint = await redis_client.keys(f"*{thread_id}*") + print(f"Keys after new checkpoint: {all_keys_after_checkpoint}") + + # Look for write keys after new checkpoint + write_keys_after_checkpoint = [ + k for k in all_keys_after_checkpoint if b"checkpoint_write" in k + ] + print(f"Write keys after new checkpoint: {write_keys_after_checkpoint}") + + # Check the write registry after new checkpoint + registry_contents_after = await redis_client.zrange(registry_key, 0, -1) + print(f"Write registry after new checkpoint: {registry_contents_after}") + + # The write keys should still exist if the bug is fixed + # If the bug is present, the write keys will be deleted + + finally: + await redis_client.aclose() + + +def test_sync_hitl_interrupt_persists(redis_url: str) -> None: + """ + Test that HITL interrupts work with the sync ShallowRedisSaver. + + This tests the same issue but with the synchronous implementation. 
+ """ + with ShallowRedisSaver.from_conn_string(redis_url) as saver: + builder = StateGraph(AgentState) + builder.add_node("review", review_node) + builder.add_node("process", process_node) + builder.add_edge(START, "review") + builder.add_edge("review", "process") + builder.add_edge("process", END) + + graph = builder.compile(checkpointer=saver) + + thread_id = f"test-sync-hitl-{uuid4()}" + config = {"configurable": {"thread_id": thread_id}} + + # First invocation - should hit the interrupt + initial_state = graph.invoke( + {"messages": [HumanMessage(content="Please review this")]}, + config=config, + ) + print(f"Sync initial state: {initial_state}") + + # Get the current state + state = graph.get_state(config) + print(f"Sync state after interrupt: {state}") + + # Check for interrupt + has_interrupt = False + for task in state.tasks: + if hasattr(task, "interrupts") and task.interrupts: + has_interrupt = True + break + + assert has_interrupt, "Interrupt should be present in sync shallow saver" + + # Resume + final_state = graph.invoke( + Command(resume={"approved": True}), + config=config, + ) + print(f"Sync final state: {final_state}") + + assert final_state.get("user_confirmed") is True + + +@requires_python_311 +@pytest.mark.asyncio +async def test_double_resume_not_required(redis_url: str) -> None: + """ + Test that verifies the interrupt doesn't require double confirmation. + + This specifically tests the symptom reported in Issue #133 where users + need to confirm twice for the interrupt to be processed. + """ + resume_count = 0 + + def counting_interrupt_node(state: AgentState) -> Dict[str, Any]: + nonlocal resume_count + resume_count += 1 + print(f"Interrupt node called, count: {resume_count}") + result = interrupt({"attempt": resume_count}) + return {"user_confirmed": True} + + async with create_async_shallow_saver(redis_url) as saver: + builder = StateGraph(AgentState) + builder.add_node("interrupt", counting_interrupt_node) + builder.add_edge(START, "interrupt") + builder.add_edge("interrupt", END) + + graph = builder.compile(checkpointer=saver) + + thread_id = f"test-double-resume-{uuid4()}" + config = {"configurable": {"thread_id": thread_id}} + + # Initial run - hits interrupt + await graph.ainvoke( + {"messages": [], "user_confirmed": False}, + config=config, + ) + + initial_resume_count = resume_count + print(f"Resume count after initial run: {initial_resume_count}") + + # First resume attempt + result = await graph.ainvoke( + Command(resume={"confirmed": True}), + config=config, + ) + + print(f"Result after first resume: {result}") + print(f"Resume count after first resume: {resume_count}") + + # Check if we completed or need another resume + state = await graph.aget_state(config) + + # If the graph is still at the interrupt, it means we need a double resume + # This would be the bug - we should complete on first resume + has_pending_interrupt = any( + hasattr(task, "interrupts") and task.interrupts for task in state.tasks + ) + + assert not has_pending_interrupt, ( + "Graph should complete after single resume, not require double confirmation. " + f"Resume was called {resume_count - initial_resume_count} time(s)." 
+        )
+
+        # Verify we only entered the interrupt node once after the initial run
+        assert resume_count == initial_resume_count + 1, (
+            f"Interrupt node should only be entered once after resume, "
+            f"but was entered {resume_count - initial_resume_count} times"
+        )
diff --git a/tests/test_shallow_sync.py b/tests/test_shallow_sync.py
index 2a1a4e3..22a9e05 100644
--- a/tests/test_shallow_sync.py
+++ b/tests/test_shallow_sync.py
@@ -545,11 +545,19 @@ def test_shallow_saver_inline_storage(redis_url: str) -> None:
        redis_client.close()


-def test_pr37_incomplete_writes_cleanup(redis_url: str) -> None:
-    """Test for PR #37: Complete cleanup of writes in ShallowRedisSaver.
+def test_pr37_writes_persist_for_hitl_support(redis_url: str) -> None:
+    """Test for PR #37 updated for Issue #133: Writes persist across checkpoints for HITL.

-    This test verifies that old writes are properly cleaned up
-    when putting new writes, now that key generation is consistent.
+    This test verifies that writes are NOT cleaned up when new checkpoints are saved.
+    This is necessary to support Human-in-the-Loop (HITL) workflows where interrupt
+    writes are saved BEFORE the new checkpoint is created.
+
+    Writes are cleaned up via:
+    1. delete_thread - explicitly cleans up all data for a thread
+    2. TTL expiration - if configured
+    3. Overwrite - when put_writes is called with the same task_id and idx
+
+    See Issue #133 for details on why this behavior is required.
    """
    import uuid
@@ -558,7 +566,7 @@ def test_pr37_incomplete_writes_cleanup(redis_url: str) -> None:
    with _saver(redis_url) as saver:
        # Create test data
        thread_id = f"test_thread_{uuid.uuid4()}"
-        checkpoint_ns = ""  # Empty namespace - problematic case
+        checkpoint_ns = ""  # Empty namespace
        checkpoint_id1 = str(uuid.uuid4())
        checkpoint_id2 = str(uuid.uuid4())
@@ -639,12 +647,26 @@ def test_pr37_incomplete_writes_cleanup(redis_url: str) -> None:
        for key in sorted(write_keys_after_second):
            print(f"  {key}")

-        # In a proper shallow implementation, old writes for different checkpoints
-        # should be cleaned up. This test verifies the cleanup works correctly.
-        # We expect only the writes from the second checkpoint (2 writes)
-        assert (
-            len(write_keys_after_second) == 2
-        ), f"Bug: old write keys not cleaned up properly. Expected 2, got {len(write_keys_after_second)}"
+        # Issue #133 fix: Writes now persist across checkpoint updates to support HITL.
+        # We expect all 4 writes to still exist (2 from checkpoint1 + 2 from checkpoint2)
+        assert len(write_keys_after_second) == 4, (
+            f"Writes should persist across checkpoints for HITL support. "
+            f"Expected 4, got {len(write_keys_after_second)}"
+        )
+
+        # Verify that delete_thread properly cleans up writes
+        saver.delete_thread(thread_id)
+
+        all_keys = redis_client.keys("*")
+        test_keys = [k for k in all_keys if thread_id in k]
+        write_keys_after_delete = [
+            k for k in test_keys if k.startswith(CHECKPOINT_WRITE_PREFIX)
+        ]
+
+        assert len(write_keys_after_delete) == 0, (
+            f"delete_thread should clean up all writes. "
" + f"Expected 0, got {len(write_keys_after_delete)}" + ) finally: redis_client.close() diff --git a/tests/test_shallow_ulid_ttl_cache.py b/tests/test_shallow_ulid_ttl_cache.py index 1b801cd..0dfd881 100644 --- a/tests/test_shallow_ulid_ttl_cache.py +++ b/tests/test_shallow_ulid_ttl_cache.py @@ -396,8 +396,18 @@ def test_metadata_dict_handling(redis_url: str) -> None: assert retrieved.metadata["writes"]["nested"]["complex"] == "structure" -def test_put_writes_cleanup_old_writes(redis_url: str) -> None: - """Test cleanup of old writes when checkpoint changes.""" +def test_put_writes_persist_for_hitl_support(redis_url: str) -> None: + """Test that writes persist across checkpoints for HITL support (Issue #133). + + Writes are NOT cleaned up when a new checkpoint is saved because this breaks + Human-in-the-Loop (HITL) workflows where interrupt writes are saved BEFORE + the new checkpoint is created. + + Writes are cleaned up via: + 1. delete_thread - explicitly cleans up all data for a thread + 2. TTL expiration - if configured + 3. Overwrite - when put_writes is called with the same task_id and idx + """ with shallow_saver(redis_url) as saver: thread_id = str(uuid4()) @@ -424,7 +434,7 @@ def test_put_writes_cleanup_old_writes(redis_url: str) -> None: writes1 = [("channel1", "value1"), ("channel2", "value2")] saver.put_writes(result_config1, writes1, "task1") - # Create second checkpoint (should trigger cleanup - lines 587-612) + # Create second checkpoint config2: RunnableConfig = { "configurable": { "thread_id": thread_id, @@ -443,21 +453,34 @@ def test_put_writes_cleanup_old_writes(redis_url: str) -> None: result_config2 = saver.put(config2, checkpoint2, metadata2, {}) - # Add writes to second checkpoint (should clean up old writes) + # Add writes to second checkpoint writes2 = [("channel3", "value3")] saver.put_writes(result_config2, writes2, "task2") - # Verify old writes are cleaned up and new writes exist + # Verify all writes persist (for HITL support) retrieved = saver.get_tuple(result_config2) assert retrieved is not None - # Should only have writes from checkpoint2 + # Issue #133 fix: All writes should persist pending_writes = retrieved.pending_writes write_channels = {w[1] for w in pending_writes} assert "channel3" in write_channels - # Old writes should be cleaned up - assert "channel1" not in write_channels - assert "channel2" not in write_channels + # Old writes should also be present (for HITL support) + assert "channel1" in write_channels + assert "channel2" in write_channels + + # Verify that delete_thread properly cleans up all writes + saver.delete_thread(thread_id) + + # After delete_thread, writes should be cleaned up + # Note: The checkpoint itself may still exist briefly due to index lag, + # but the writes should be cleaned up immediately + retrieved_after_delete = saver.get_tuple(result_config2) + if retrieved_after_delete is not None: + # If checkpoint still exists, verify writes are cleaned up + assert ( + len(retrieved_after_delete.pending_writes) == 0 + ), "Writes should be cleaned up by delete_thread" def test_error_handling_missing_checkpoint(redis_url: str) -> None: