Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion langgraph/store/redis/aio.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ async def __aexit__(
) -> None:
"""Async context manager exit."""
# Cancel the background task created by AsyncBatchedBaseStore
if hasattr(self, "_task") and not self._task.done():
if hasattr(self, "_task") and self._task is not None and not self._task.done():
self._task.cancel()
try:
await self._task
Expand Down
128 changes: 49 additions & 79 deletions test_jsonplus_redis_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,16 @@ def test_human_message_serialization():
serializer = JsonPlusRedisSerializer()
msg = HumanMessage(content="What is the weather?", id="msg-1")

try:
# This would raise TypeError before the fix
serialized = serializer.dumps(msg)
print(f" ✓ Serialized to {len(serialized)} bytes")

# Deserialize
deserialized = serializer.loads(serialized)
assert isinstance(deserialized, HumanMessage)
assert deserialized.content == "What is the weather?"
assert deserialized.id == "msg-1"
print(f" ✓ Deserialized correctly: {deserialized.content}")
# This would raise TypeError before the fix
serialized = serializer.dumps(msg)
print(f" ✓ Serialized to {len(serialized)} bytes")

return True
except TypeError as e:
print(f" ✗ FAILED: {e}")
return False
# Deserialize
deserialized = serializer.loads(serialized)
assert isinstance(deserialized, HumanMessage)
assert deserialized.content == "What is the weather?"
assert deserialized.id == "msg-1"
print(f" ✓ Deserialized correctly: {deserialized.content}")


def test_all_message_types():
Expand All @@ -45,16 +39,10 @@ def test_all_message_types():
]

for msg in messages:
try:
serialized = serializer.dumps(msg)
deserialized = serializer.loads(serialized)
assert type(deserialized) == type(msg)
print(f" ✓ {type(msg).__name__} works")
except Exception as e:
print(f" ✗ {type(msg).__name__} FAILED: {e}")
return False

return True
serialized = serializer.dumps(msg)
deserialized = serializer.loads(serialized)
assert type(deserialized) == type(msg)
print(f" ✓ {type(msg).__name__} works")


def test_message_list():
Expand All @@ -68,19 +56,13 @@ def test_message_list():
HumanMessage(content="Question 2"),
]

try:
serialized = serializer.dumps(messages)
deserialized = serializer.loads(serialized)

assert isinstance(deserialized, list)
assert len(deserialized) == 3
assert all(isinstance(m, (HumanMessage, AIMessage)) for m in deserialized)
print(f" ✓ List of {len(deserialized)} messages works")
serialized = serializer.dumps(messages)
deserialized = serializer.loads(serialized)

return True
except Exception as e:
print(f" ✗ FAILED: {e}")
return False
assert isinstance(deserialized, list)
assert len(deserialized) == 3
assert all(isinstance(m, (HumanMessage, AIMessage)) for m in deserialized)
print(f" ✓ List of {len(deserialized)} messages works")


def test_nested_structure():
Expand All @@ -96,20 +78,14 @@ def test_nested_structure():
"step": 1,
}

try:
serialized = serializer.dumps(state)
deserialized = serializer.loads(serialized)

assert "messages" in deserialized
assert len(deserialized["messages"]) == 2
assert isinstance(deserialized["messages"][0], HumanMessage)
assert isinstance(deserialized["messages"][1], AIMessage)
print(f" ✓ Nested structure works")
serialized = serializer.dumps(state)
deserialized = serializer.loads(serialized)

return True
except Exception as e:
print(f" ✗ FAILED: {e}")
return False
assert "messages" in deserialized
assert len(deserialized["messages"]) == 2
assert isinstance(deserialized["messages"][0], HumanMessage)
assert isinstance(deserialized["messages"][1], AIMessage)
print(f" ✓ Nested structure works")


def test_dumps_typed():
Expand All @@ -119,21 +95,15 @@ def test_dumps_typed():
serializer = JsonPlusRedisSerializer()
msg = HumanMessage(content="Test", id="test-123")

try:
type_str, blob = serializer.dumps_typed(msg)
assert type_str == "json"
assert isinstance(blob, str)
print(f" ✓ dumps_typed returns: type='{type_str}', blob={len(blob)} chars")

deserialized = serializer.loads_typed((type_str, blob))
assert isinstance(deserialized, HumanMessage)
assert deserialized.content == "Test"
print(f" ✓ loads_typed works correctly")
type_str, blob = serializer.dumps_typed(msg)
assert type_str == "json"
assert isinstance(blob, str)
print(f" ✓ dumps_typed returns: type='{type_str}', blob={len(blob)} chars")

return True
except Exception as e:
print(f" ✗ FAILED: {e}")
return False
deserialized = serializer.loads_typed((type_str, blob))
assert isinstance(deserialized, HumanMessage)
assert deserialized.content == "Test"
print(f" ✓ loads_typed works correctly")


def test_backwards_compatibility():
Expand All @@ -149,16 +119,10 @@ def test_backwards_compatibility():
]

for name, obj in test_cases:
try:
serialized = serializer.dumps(obj)
deserialized = serializer.loads(serialized)
assert deserialized == obj
print(f" ✓ {name} works")
except Exception as e:
print(f" ✗ {name} FAILED: {e}")
return False

return True
serialized = serializer.dumps(obj)
deserialized = serializer.loads(serialized)
assert deserialized == obj
print(f" ✓ {name} works")


def main():
Expand All @@ -176,19 +140,25 @@ def main():
test_backwards_compatibility,
]

results = []
passed = 0
failed = 0
for test in tests:
results.append(test())
try:
test()
passed += 1
except Exception as e:
print(f" ✗ {test.__name__} FAILED: {e}")
failed += 1

print("\n" + "=" * 70)
print(f"Results: {sum(results)}/{len(results)} tests passed")
print(f"Results: {passed}/{len(tests)} tests passed")
print("=" * 70)

if all(results):
if failed == 0:
print("\n✅ ALL TESTS PASSED - Fix is working correctly!")
return 0
else:
print("\n❌ SOME TESTS FAILED - Fix may not be working")
print(f"\n❌ {failed} TESTS FAILED - Fix may not be working")
return 1


Expand Down
179 changes: 179 additions & 0 deletions tests/test_issue_116_async_blob_access.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
"""
Regression test for Issue #116: AsyncRedisSaver AttributeError when calling aget_state_history()

This test verifies that the async implementation correctly handles blob access
when using _abatch_load_pending_sends with the JSON path syntax ($.blob).

The bug manifested as:
AttributeError: 'Document' object has no attribute 'blob'

This was caused by a mismatch between:
1. The return_fields specification ("blob" instead of "$.blob")
2. The attribute access pattern (direct access d.blob instead of getattr(d, "$.blob", ...))

The fix aligns the async implementation with the sync version by:
1. Using "$.blob" in return_fields
2. Using getattr(doc, "$.blob", getattr(doc, "blob", b"")) for access
"""

from typing import Any, Dict
from unittest.mock import MagicMock

import pytest

from langgraph.checkpoint.redis.aio import AsyncRedisSaver


class MockDocument:
    """Stand-in for a Redis search Document with JSON-path-style attributes.

    When a query's ``return_fields`` includes ``"$.blob"``, Redis exposes the
    value under the literal attribute name ``"$.blob"`` rather than ``"blob"``
    — this mock reproduces that quirk so the access pattern under test can be
    exercised without a real index.
    """

    def __init__(self, data: Dict[str, Any]):
        # Plain scalar fields fall back to empty/zero defaults.
        for field, default in (
            ("checkpoint_id", ""),
            ("type", ""),
            ("task_path", ""),
            ("task_id", ""),
            ("idx", 0),
        ):
            setattr(self, field, data.get(field, default))
        # Only create the "$.blob" attribute when blob data was supplied,
        # mirroring Redis omitting JSON paths that matched nothing.
        if "json_blob" in data:
            setattr(self, "$.blob", data["json_blob"])


@pytest.mark.asyncio
async def test_abatch_load_pending_sends_with_json_path_blob(redis_url: str) -> None:
    """
    Test that _abatch_load_pending_sends correctly handles $.blob JSON path attribute.

    This is a unit test with mocked Redis responses that directly tests the bug fix.
    Before the fix, accessing d.blob would raise AttributeError because Redis returns
    the attribute as "$.blob" (not "blob") when you specify "$.blob" in return_fields.
    """
    async with AsyncRedisSaver.from_conn_string(redis_url) as saver:
        await saver.asetup()

        # Canned search hits: each spec becomes a document whose blob lives
        # under the JSON-path attribute name "$.blob" (not "blob").
        doc_specs = [
            {
                "checkpoint_id": "checkpoint_1",
                "type": "test_type1",
                "task_path": "path1",
                "task_id": "task1",
                "idx": 0,
                "json_blob": b"data1",
            },
            {
                "checkpoint_id": "checkpoint_1",
                "type": "test_type2",
                "task_path": "path2",
                "task_id": "task2",
                "idx": 1,
                "json_blob": b"data2",
            },
            {
                "checkpoint_id": "checkpoint_2",
                "type": "test_type3",
                "task_path": "path3",
                "task_id": "task3",
                "idx": 0,
                "json_blob": b"data3",
            },
        ]
        fake_result = MagicMock()
        fake_result.docs = [MockDocument(spec) for spec in doc_specs]

        # Swap in a search stub that always yields the canned result.
        saved_search = saver.checkpoint_writes_index.search

        async def fake_search(_: Any) -> MagicMock:
            return fake_result

        saver.checkpoint_writes_index.search = fake_search

        try:
            # Exercise the method that raised AttributeError before the fix:
            # internally it must read each doc's "$.blob" attribute.
            result = await saver._abatch_load_pending_sends(
                [
                    ("test_thread", "test_ns", "checkpoint_1"),
                    ("test_thread", "test_ns", "checkpoint_2"),
                ]
            )

            key_1 = ("test_thread", "test_ns", "checkpoint_1")
            key_2 = ("test_thread", "test_ns", "checkpoint_2")
            assert key_1 in result
            assert key_2 in result

            # Blob bytes must have been read via the "$.blob" attribute.
            data_1 = result[key_1]
            assert len(data_1) == 2
            assert data_1[0] == ("test_type1", b"data1")
            assert data_1[1] == ("test_type2", b"data2")

            data_2 = result[key_2]
            assert len(data_2) == 1
            assert data_2[0] == ("test_type3", b"data3")

        finally:
            # Always undo the monkey-patch, even on assertion failure.
            saver.checkpoint_writes_index.search = saved_search


@pytest.mark.asyncio
async def test_abatch_load_pending_sends_handles_missing_blob(redis_url: str) -> None:
    """
    Test that _abatch_load_pending_sends gracefully handles missing blob attributes.

    This tests the fallback logic: getattr(doc, "$.blob", getattr(doc, "blob", b""))
    """
    async with AsyncRedisSaver.from_conn_string(redis_url) as saver:
        await saver.asetup()

        # One document carries a "$.blob" attribute; the other omits it,
        # simulating a JSON path that matched nothing in Redis.
        fake_result = MagicMock()
        fake_result.docs = [
            MockDocument(
                {
                    "checkpoint_id": "checkpoint_1",
                    "type": "test_type1",
                    "task_path": "p1",
                    "task_id": "t1",
                    "idx": 0,
                    "json_blob": b"data1",
                }
            ),
            MockDocument(
                {
                    "checkpoint_id": "checkpoint_1",
                    "type": "test_type2",
                    "task_path": "p2",
                    "task_id": "t2",
                    "idx": 1,
                    # json_blob intentionally absent: no "$.blob" attribute.
                }
            ),
        ]

        saved_search = saver.checkpoint_writes_index.search

        async def fake_search(_: Any) -> MagicMock:
            return fake_result

        saver.checkpoint_writes_index.search = fake_search

        try:
            result = await saver._abatch_load_pending_sends(
                [("test_thread", "test_ns", "checkpoint_1")]
            )

            # A missing blob must degrade to empty bytes, not raise.
            data = result[("test_thread", "test_ns", "checkpoint_1")]
            assert len(data) == 2
            assert data[0] == ("test_type1", b"data1")
            assert data[1] == ("test_type2", b"")  # Fallback to b""

        finally:
            # Restore the real search method regardless of outcome.
            saver.checkpoint_writes_index.search = saved_search
1 change: 1 addition & 0 deletions tests/test_jsonplus_serializer_default_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import pytest
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage

from langgraph.checkpoint.redis.jsonplus_redis import JsonPlusRedisSerializer


Expand Down