18 changes: 10 additions & 8 deletions langgraph/checkpoint/redis/aio.py
@@ -1063,7 +1063,7 @@ async def aput(

if self.cluster_mode:
# For cluster mode, execute operation directly
await self._redis.json().set( # type: ignore[misc]
await self._redis.json().set(
checkpoint_key, "$", checkpoint_data
)
else:
@@ -1146,7 +1146,7 @@ async def aput_writes(
)

# Redis JSON.SET is an UPSERT by default
await self._redis.json().set(key, "$", cast(Any, write_obj)) # type: ignore[misc]
await self._redis.json().set(key, "$", cast(Any, write_obj))
created_keys.append(key)

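The comment above relies on `JSON.SET` behaving as an upsert when targeting the root path `"$"`: the call creates the document if the key is missing and replaces it wholesale otherwise. A minimal sketch of that behaviour, assuming a local Redis Stack instance and an illustrative key (neither comes from this PR):

```python
import redis

# Hypothetical connection; the PR itself uses the checkpointer's own self._redis client.
r = redis.Redis(host="localhost", port=6379, decode_responses=True)

key = "checkpoint_write:example"  # illustrative key, not the PR's real key schema

# First call creates the JSON document at the root path "$".
r.json().set(key, "$", {"task_id": "t1", "idx": 0, "blob": "v1"})

# A second call with the same key and "$" replaces the whole document (upsert).
r.json().set(key, "$", {"task_id": "t1", "idx": 0, "blob": "v2"})

print(r.json().get(key))  # -> {'task_id': 't1', 'idx': 0, 'blob': 'v2'}
```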
# Apply TTL to newly created keys
@@ -1181,7 +1181,7 @@ async def aput_writes(

# Add all write keys with their index as score for ordering
zadd_mapping = {key: idx for idx, key in enumerate(write_keys)}
await self._redis.zadd(zset_key, zadd_mapping)
await self._redis.zadd(zset_key, zadd_mapping) # type: ignore[arg-type]

# Apply TTL to registry key if configured
if self.ttl_config and "default_ttl" in self.ttl_config:
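The registry pattern used throughout this PR stores each write key in a sorted set with its insertion index as the score, so a plain `ZRANGE` by rank returns the keys in original write order. The added `# type: ignore[arg-type]` comments appear to silence mypy's complaint about passing a `dict[str, int]` comprehension where redis-py's `zadd` stub expects a broader mapping type; the runtime call is unchanged. A hedged sketch with an illustrative client and key names:

```python
import redis

r = redis.Redis(decode_responses=True)  # hypothetical standalone client

write_keys = [
    "checkpoint_write:thread-1:ns:ckpt-1:task-a:0",
    "checkpoint_write:thread-1:ns:ckpt-1:task-a:1",
    "checkpoint_write:thread-1:ns:ckpt-1:task-a:2",
]
zset_key = "write_keys_zset:thread-1:ns:ckpt-1"  # illustrative registry key

# Index-as-score preserves write order when the registry is read back.
zadd_mapping = {key: idx for idx, key in enumerate(write_keys)}
r.zadd(zset_key, zadd_mapping)

# ZRANGE by rank returns members ordered by score, i.e. original write order.
assert r.zrange(zset_key, 0, -1) == write_keys
```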
@@ -1243,7 +1243,7 @@ async def aput_writes(

# Add all write keys with their index as score for ordering
zadd_mapping = {key: idx for idx, key in enumerate(write_keys)}
pipeline.zadd(zset_key, zadd_mapping)
pipeline.zadd(zset_key, zadd_mapping) # type: ignore[arg-type]

# Apply TTL to registry key if configured
if self.ttl_config and "default_ttl" in self.ttl_config:
@@ -1291,7 +1291,7 @@ async def aput_writes(
zadd_mapping = {
key: idx for idx, key in enumerate(write_keys)
}
fallback_pipeline.zadd(zset_key, zadd_mapping)
fallback_pipeline.zadd(zset_key, zadd_mapping) # type: ignore[arg-type]
if self.ttl_config and "default_ttl" in self.ttl_config:
ttl_seconds = int(
self.ttl_config.get("default_ttl") * 60
@@ -1304,14 +1304,16 @@ async def aput_writes(
# Update has_writes flag separately for older Redis
if checkpoint_key:
try:
checkpoint_data = await self._redis.json().get(checkpoint_key) # type: ignore[misc]
checkpoint_data = await self._redis.json().get(
checkpoint_key
)
if isinstance(
checkpoint_data, dict
) and not checkpoint_data.get("has_writes"):
checkpoint_data["has_writes"] = True
await self._redis.json().set(
checkpoint_key, "$", checkpoint_data
) # type: ignore[misc]
)
except Exception:
# If this fails, it's not critical - the writes are still saved
pass
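The `has_writes` update for older Redis versions is a best-effort read-modify-write: fetch the checkpoint document, flip the flag, write it back, and swallow failures because the pending writes are already persisted under their own keys. A condensed sketch of that sequence, assuming `redis.asyncio` and an illustrative checkpoint key:

```python
import asyncio

import redis.asyncio as redis


async def mark_has_writes(r: redis.Redis, checkpoint_key: str) -> None:
    """Best-effort flip of the has_writes flag on an existing checkpoint doc (sketch)."""
    try:
        doc = await r.json().get(checkpoint_key)
        if isinstance(doc, dict) and not doc.get("has_writes"):
            doc["has_writes"] = True
            await r.json().set(checkpoint_key, "$", doc)
    except Exception:
        # Non-critical: the pending writes already live under their own keys.
        pass


async def main() -> None:
    r = redis.Redis(decode_responses=True)  # hypothetical connection
    await mark_has_writes(r, "checkpoint:thread-1:ns:ckpt-1")  # illustrative key
    await r.aclose()


asyncio.run(main())
```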
@@ -1477,7 +1479,7 @@ async def aget_channel_values(
)

# Single JSON.GET operation to retrieve checkpoint with inline channel_values
checkpoint_data = await self._redis.json().get(checkpoint_key, "$.checkpoint") # type: ignore[misc]
checkpoint_data = await self._redis.json().get(checkpoint_key, "$.checkpoint")

if not checkpoint_data:
return {}
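Fetching only `$.checkpoint` keeps the round trip small, but note that RedisJSON returns a list of matches for `$`-style paths, so callers generally unwrap the first element. A small sketch under those assumptions (client setup and key are illustrative):

```python
import redis

r = redis.Redis(decode_responses=True)  # hypothetical client
checkpoint_key = "checkpoint:thread-1:ns:ckpt-1"  # illustrative key

r.json().set(
    checkpoint_key,
    "$",
    {"checkpoint": {"id": "ckpt-1", "channel_values": {"messages": []}}, "metadata": {}},
)

# A "$"-style JSONPath returns a list of matches rather than a bare object.
matches = r.json().get(checkpoint_key, "$.checkpoint")
checkpoint = matches[0] if isinstance(matches, list) and matches else matches
print(checkpoint["channel_values"])  # -> {'messages': []}
```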
69 changes: 26 additions & 43 deletions langgraph/checkpoint/redis/ashallow.py
@@ -94,7 +94,7 @@ async def __aexit__(
_tb: Optional[TracebackType],
) -> None:
if self._owns_its_client:
await self._redis.aclose()
await self._redis.aclose() # type: ignore[attr-defined]
# RedisCluster doesn't have connection_pool attribute
if getattr(self._redis, "connection_pool", None):
coro = self._redis.connection_pool.disconnect()
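The `# type: ignore[attr-defined]` on `aclose()` together with the `getattr` guard reflects the split between standalone and cluster clients: `RedisCluster` exposes no `connection_pool` attribute, and the stubs for the union of both client types don't always declare `aclose`. A hedged sketch of the cleanup pattern (the helper name is illustrative):

```python
from typing import Union

import redis.asyncio as redis
from redis.asyncio.cluster import RedisCluster


async def close_owned_client(client: Union[redis.Redis, RedisCluster]) -> None:
    """Close a client we created, tolerating cluster/standalone API differences (sketch)."""
    await client.aclose()
    # RedisCluster has no connection_pool attribute; only disconnect it when present.
    pool = getattr(client, "connection_pool", None)
    if pool is not None:
        await pool.disconnect()
```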
@@ -229,9 +229,6 @@ async def aput(
# Create pipeline for all operations
pipeline = self._redis.pipeline(transaction=False)

# Get the previous checkpoint ID to potentially clean up its writes
pipeline.json().get(checkpoint_key)

# Set the new checkpoint data
pipeline.json().set(checkpoint_key, "$", checkpoint_data)

@@ -240,41 +237,21 @@ async def aput(
ttl_seconds = int(self.ttl_config.get("default_ttl") * 60)
pipeline.expire(checkpoint_key, ttl_seconds)

# Execute pipeline to get prev data and set new data
results = await pipeline.execute()
prev_checkpoint_data = results[0]

# Check if we need to clean up old writes
prev_checkpoint_id = None
if prev_checkpoint_data and isinstance(prev_checkpoint_data, dict):
prev_checkpoint_id = prev_checkpoint_data.get("checkpoint_id")

# If checkpoint changed, clean up old writes in a second pipeline
if prev_checkpoint_id and prev_checkpoint_id != checkpoint["id"]:
thread_zset_key = f"write_keys_zset:{thread_id}:{checkpoint_ns}:shallow"

# Create cleanup pipeline
cleanup_pipeline = self._redis.pipeline(transaction=False)

# Get all existing write keys
cleanup_pipeline.zrange(thread_zset_key, 0, -1)

# Delete the registry
cleanup_pipeline.delete(thread_zset_key)

# Execute to get keys and delete registry
cleanup_results = await cleanup_pipeline.execute()
existing_write_keys = cleanup_results[0]
# Execute pipeline to set new checkpoint data
await pipeline.execute()

# If there are keys to delete, do it in another pipeline
if existing_write_keys:
delete_pipeline = self._redis.pipeline(transaction=False)
for old_key in existing_write_keys:
old_key_str = (
old_key.decode() if isinstance(old_key, bytes) else old_key
)
delete_pipeline.delete(old_key_str)
await delete_pipeline.execute()
# NOTE: We intentionally do NOT clean up old writes here.
# In the HITL (Human-in-the-Loop) flow, interrupt writes are saved via
# put_writes BEFORE the new checkpoint is saved. If we clean up writes
# when the checkpoint changes, we would delete the interrupt writes
# before they can be consumed when resuming.
#
# Writes are cleaned up in the following scenarios:
# 1. When delete_thread is called
# 2. When TTL expires (if configured)
# 3. When put_writes is called again for the same task/idx (overwrites)
#
# See Issue #133 for details on this bug fix.

return next_config

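The NOTE above hinges on call ordering: in a human-in-the-loop run, interrupt writes are persisted via `put_writes` before the next checkpoint is stored, so cleaning up "old" writes during `aput` would delete exactly the data the resume step needs. The toy in-memory model below illustrates the failure mode and the fix; it is a sketch only, not the library's implementation:

```python
# Toy in-memory model of the bug fixed here - an illustration, not the library code.
store: dict[str, object] = {}
write_registry: list[str] = []


def put_writes(task_id: str, idx: int, value: object) -> None:
    key = f"write:{task_id}:{idx}"
    store[key] = value
    write_registry.append(key)


def put_checkpoint(checkpoint_id: str, cleanup_old_writes: bool) -> None:
    if cleanup_old_writes and store.get("checkpoint") != checkpoint_id:
        # Old behaviour: drop every write registered so far.
        for key in write_registry:
            store.pop(key, None)
        write_registry.clear()
    store["checkpoint"] = checkpoint_id


# HITL flow: the interrupt write lands *before* the next checkpoint.
put_writes("t1", 0, "interrupt payload")
put_checkpoint("ckpt-2", cleanup_old_writes=True)
print(write_registry)  # -> [] : the interrupt write was lost before resume

# With the fix (no cleanup in put), the write survives until delete/TTL/overwrite.
put_writes("t1", 0, "interrupt payload")
put_checkpoint("ckpt-3", cleanup_old_writes=False)
print(write_registry)  # -> ['write:t1:0']
```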
@@ -388,7 +365,7 @@ async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
)

# Single fetch gets everything inline - matching sync implementation
full_checkpoint_data = await self._redis.json().get(checkpoint_key) # type: ignore[misc]
full_checkpoint_data = await self._redis.json().get(checkpoint_key)
if not full_checkpoint_data or not isinstance(full_checkpoint_data, dict):
return None

@@ -505,7 +482,11 @@ async def aput_writes(
writes_objects.append(write_obj)

# Thread-level sorted set for write keys
thread_zset_key = f"write_keys_zset:{thread_id}:{checkpoint_ns}:shallow"
# Use to_storage_safe_str for consistent key naming
safe_checkpoint_ns = to_storage_safe_str(checkpoint_ns)
thread_zset_key = (
f"write_keys_zset:{thread_id}:{safe_checkpoint_ns}:shallow"
)

# Collect all write keys
write_keys = []
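Routing the namespace through `to_storage_safe_str` matters because `delete_thread` builds its keys the same way; without it, an empty `checkpoint_ns` would produce one key on write and a different key on delete, leaving orphaned registries. A sketch of the intended equivalence, using a stand-in helper whose sentinel value is illustrative (the real helper lives in the library's util module):

```python
# Stand-in mirroring the helper's contract: an empty namespace must serialize to a
# stable, non-empty sentinel so writers and delete_thread compute identical keys.
# The sentinel value here is illustrative only.
def to_storage_safe_str(value: str) -> str:
    return value if value else "__empty__"


def write_zset_key(thread_id: str, checkpoint_ns: str) -> str:
    safe_ns = to_storage_safe_str(checkpoint_ns)
    return f"write_keys_zset:{thread_id}:{safe_ns}:shallow"


# Without the helper, an empty namespace yields "write_keys_zset:t1::shallow" on write
# while delete_thread looks for the sentinel-based key, so the registry is never removed.
assert write_zset_key("t1", "") == "write_keys_zset:t1:__empty__:shallow"
assert write_zset_key("t1", "agent") == "write_keys_zset:t1:agent:shallow"
```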
@@ -529,7 +510,7 @@

# Use thread-level sorted set
zadd_mapping = {key: idx for idx, key in enumerate(write_keys)}
pipeline.zadd(thread_zset_key, zadd_mapping)
pipeline.zadd(thread_zset_key, zadd_mapping) # type: ignore[arg-type]

# Apply TTL to registry key if configured
if self.ttl_config and "default_ttl" in self.ttl_config:
@@ -563,7 +544,7 @@ async def aget_channel_values(
)

# Single JSON.GET operation to retrieve checkpoint with inline channel_values
checkpoint_data = await self._redis.json().get(checkpoint_key, "$.checkpoint") # type: ignore[misc]
checkpoint_data = await self._redis.json().get(checkpoint_key, "$.checkpoint")

if not checkpoint_data:
return {}
@@ -631,7 +612,9 @@ async def _aload_pending_writes(
return []

# Use thread-level sorted set
thread_zset_key = f"write_keys_zset:{thread_id}:{checkpoint_ns}:shallow"
# Use to_storage_safe_str for consistent key naming
safe_checkpoint_ns = to_storage_safe_str(checkpoint_ns)
thread_zset_key = f"write_keys_zset:{thread_id}:{safe_checkpoint_ns}:shallow"

try:
# Check if we have any writes in the thread sorted set
12 changes: 6 additions & 6 deletions langgraph/checkpoint/redis/base.py
@@ -563,7 +563,7 @@ def _load_writes_from_redis(self, write_key: str) -> List[Tuple[str, str, Any]]:
return []

writes = []
for write in result["writes"]: # type: ignore[call-overload]
for write in result["writes"]:
writes.append(
(
write["task_id"],
@@ -636,17 +636,17 @@ def put_writes(
# UPSERT case - only update specific fields
if key_exists:
# Update only channel, type, and blob fields
pipeline.set(key, "$.channel", write_obj["channel"])
pipeline.set(key, "$.type", write_obj["type"])
pipeline.set(key, "$.blob", write_obj["blob"])
pipeline.json().set(key, "$.channel", write_obj["channel"])
pipeline.json().set(key, "$.type", write_obj["type"])
pipeline.json().set(key, "$.blob", write_obj["blob"])
else:
# For new records, set the complete object
pipeline.set(key, "$", write_obj)
pipeline.json().set(key, "$", write_obj)
created_keys.append(key)
else:
# INSERT case - only insert if doesn't exist
if not key_exists:
pipeline.set(key, "$", write_obj)
pipeline.json().set(key, "$", write_obj)
created_keys.append(key)

pipeline.execute()
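The original calls went through the plain string `SET` command, which does not understand JSON paths at all; switching to `pipeline.json().set(...)` lets the upsert branch patch individual fields (`$.channel`, `$.type`, `$.blob`) while the insert branch writes the whole document at `"$"`. A hedged sketch of that field-level update pattern with an illustrative key and payload:

```python
import redis

r = redis.Redis(decode_responses=True)  # hypothetical client
key = "checkpoint_write:thread-1:ns:ckpt-1:task-a:0"  # illustrative key

# INSERT case: write the complete object at the root path.
r.json().set(
    key,
    "$",
    {"task_id": "task-a", "idx": 0, "channel": "messages", "type": "json", "blob": "old"},
)

# UPSERT case: patch only the mutable fields of the existing document.
r.json().set(key, "$.channel", "messages")
r.json().set(key, "$.type", "json")
r.json().set(key, "$.blob", "new")

print(r.json().get(key, "$.blob"))  # -> ['new']
```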
4 changes: 2 additions & 2 deletions langgraph/checkpoint/redis/key_registry.py
@@ -105,7 +105,7 @@ def register_write_keys_batch(
)
# Use index as score to maintain order
mapping = {key: idx for idx, key in enumerate(write_keys)}
self._redis.zadd(zset_key, mapping)
self._redis.zadd(zset_key, mapping) # type: ignore[arg-type]

def get_write_keys(
self, thread_id: str, checkpoint_ns: str, checkpoint_id: str
@@ -215,7 +215,7 @@ async def register_write_keys_batch(
thread_id, checkpoint_ns, checkpoint_id
)
mapping = {key: idx for idx, key in enumerate(write_keys)}
await self._redis.zadd(zset_key, mapping)
await self._redis.zadd(zset_key, mapping) # type: ignore[arg-type]

async def get_write_keys(
self, thread_id: str, checkpoint_ns: str, checkpoint_id: str
49 changes: 20 additions & 29 deletions langgraph/checkpoint/redis/shallow.py
@@ -200,42 +200,29 @@ def put(
thread_id, checkpoint_ns
)

# Get the previous checkpoint ID to clean up its writes
prev_checkpoint_data = self._redis.json().get(checkpoint_key)
prev_checkpoint_id = None
if prev_checkpoint_data and isinstance(prev_checkpoint_data, dict):
prev_checkpoint_id = prev_checkpoint_data.get("checkpoint_id")

with self._redis.pipeline(transaction=False) as pipeline:
pipeline.json().set(checkpoint_key, "$", checkpoint_data)

# If checkpoint changed, clean up old writes
if prev_checkpoint_id and prev_checkpoint_id != checkpoint["id"]:
# Clean up writes from the previous checkpoint
thread_write_registry_key = (
f"write_registry:{thread_id}:{checkpoint_ns}:shallow"
)

# Get all existing write keys and delete them
existing_write_keys = self._redis.zrange(
thread_write_registry_key, 0, -1
)
for old_key in existing_write_keys:
old_key_str = (
old_key.decode() if isinstance(old_key, bytes) else old_key
)
pipeline.delete(old_key_str)

# Clear the registry
pipeline.delete(thread_write_registry_key)

# Apply TTL if configured
if self.ttl_config and "default_ttl" in self.ttl_config:
ttl_seconds = int(self.ttl_config.get("default_ttl") * 60)
pipeline.expire(checkpoint_key, ttl_seconds)

pipeline.execute()

# NOTE: We intentionally do NOT clean up old writes here.
# In the HITL (Human-in-the-Loop) flow, interrupt writes are saved via
# put_writes BEFORE the new checkpoint is saved. If we clean up writes
# when the checkpoint changes, we would delete the interrupt writes
# before they can be consumed when resuming.
#
# Writes are cleaned up in the following scenarios:
# 1. When delete_thread is called
# 2. When TTL expires (if configured)
# 3. When put_writes is called again for the same task/idx (overwrites)
#
# See Issue #133 for details on this bug fix.

return next_config

def list(
@@ -501,8 +488,10 @@ def put_writes(
writes_objects.append(write_obj)

# THREAD-LEVEL REGISTRY: Only keep writes for the current checkpoint
# Use to_storage_safe_str for consistent key naming with delete_thread
safe_checkpoint_ns = to_storage_safe_str(checkpoint_ns)
thread_write_registry_key = (
f"write_registry:{thread_id}:{checkpoint_ns}:shallow"
f"write_registry:{thread_id}:{safe_checkpoint_ns}:shallow"
)

# Collect all write keys
@@ -525,7 +514,7 @@
# THREAD-LEVEL REGISTRY: Store write keys in thread-level sorted set
# These will be cleared when checkpoint changes
zadd_mapping = {key: idx for idx, key in enumerate(write_keys)}
pipeline.zadd(thread_write_registry_key, zadd_mapping)
pipeline.zadd(thread_write_registry_key, zadd_mapping) # type: ignore[arg-type]

# Note: We don't update has_writes on the checkpoint anymore
# because put_writes can be called before the checkpoint exists
@@ -550,8 +539,10 @@ def _load_pending_writes(

# Use thread-level registry that only contains current checkpoint writes
# All writes belong to the current checkpoint
# Use to_storage_safe_str for consistent key naming with delete_thread
safe_checkpoint_ns = to_storage_safe_str(checkpoint_ns)
thread_write_registry_key = (
f"write_registry:{thread_id}:{checkpoint_ns}:shallow"
f"write_registry:{thread_id}:{safe_checkpoint_ns}:shallow"
)

# Get all write keys from the thread's registry (already sorted by index)
2 changes: 1 addition & 1 deletion langgraph/store/redis/__init__.py
@@ -541,7 +541,7 @@ def _batch_search_ops(
if not isinstance(store_doc, dict):
try:
store_doc = json.loads(
store_doc # type: ignore[arg-type]
store_doc
) # Attempt to parse if it's a JSON string
except (json.JSONDecodeError, TypeError):
logger.error(f"Failed to parse store_doc: {store_doc}")
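Depending on client configuration (for example `decode_responses` or cluster mode), the stored document can come back as a dict or as a raw JSON string, so the search path parses defensively and logs instead of raising on malformed payloads. A small sketch of that normalization; the logger name and helper function are hypothetical:

```python
import json
import logging
from typing import Any, Optional

logger = logging.getLogger("example")  # hypothetical logger name


def normalize_store_doc(store_doc: Any) -> Optional[dict]:
    """Return a dict for either dict or JSON-string payloads (sketch)."""
    if isinstance(store_doc, dict):
        return store_doc
    try:
        return json.loads(store_doc)  # attempt to parse if it's a JSON string
    except (json.JSONDecodeError, TypeError):
        logger.error(f"Failed to parse store_doc: {store_doc}")
        return None


print(normalize_store_doc('{"value": 1}'))  # -> {'value': 1}
print(normalize_store_doc({"value": 2}))    # -> {'value': 2}
print(normalize_store_doc(None))            # -> None (and an error is logged)
```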
4 changes: 2 additions & 2 deletions langgraph/store/redis/aio.py
@@ -376,7 +376,7 @@ async def __aexit__(

# Close Redis connections if we own them
if self._owns_its_client:
await self._redis.aclose()
await self._redis.aclose() # type: ignore[attr-defined]
await self._redis.connection_pool.disconnect()

async def abatch(self, ops: Iterable[Op]) -> list[Result]:
@@ -781,7 +781,7 @@ async def _batch_search_ops(
)
result_map[store_key] = doc
# Fetch individually in cluster mode
store_doc_item = await self._redis.json().get(store_key) # type: ignore
store_doc_item = await self._redis.json().get(store_key)
store_docs.append(store_doc_item)
store_docs_raw = store_docs
else: