Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 14 additions & 10 deletions langgraph/checkpoint/redis/jsonplus_redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,10 @@ class JsonPlusRedisSerializer(JsonPlusSerializer):
]

def dumps(self, obj: Any) -> bytes:
"""Use orjson for simple objects, fallback to parent for complex objects."""
try:
# Fast path: Use orjson for JSON-serializable objects
return orjson.dumps(obj)
except TypeError:
# Complex objects (Send, etc.) need parent's msgpack serialization
return super().dumps(obj)
"""Use orjson for serialization with LangChain object support via default handler."""
# Use orjson with default handler for LangChain objects
# The _default method from parent class handles LangChain serialization
return orjson.dumps(obj, default=self._default)

def loads(self, data: bytes) -> Any:
"""Use orjson for JSON parsing with reviver support, fallback to parent for msgpack data."""
Expand All @@ -54,9 +51,15 @@ def loads(self, data: bytes) -> Any:
parsed = orjson.loads(data)
# Apply reviver for LangChain objects (lc format)
return self._revive_if_needed(parsed)
except orjson.JSONDecodeError:
# Fallback: Parent handles msgpack and other formats
return super().loads(data)
except (orjson.JSONDecodeError, TypeError):
# Fallback: Parent handles msgpack and other formats via loads_typed
# Attempt to detect type and use loads_typed
try:
# Try loading as msgpack via parent's loads_typed
return super().loads_typed(("msgpack", data))
except Exception:
# If that fails, try loading as json string
return super().loads_typed(("json", data))

def _revive_if_needed(self, obj: Any) -> Any:
"""Recursively apply reviver to handle LangChain serialized objects.
Expand Down Expand Up @@ -93,6 +96,7 @@ def dumps_typed(self, obj: Any) -> tuple[str, str]: # type: ignore[override]
if isinstance(obj, (bytes, bytearray)):
return "base64", base64.b64encode(obj).decode("utf-8")
else:
# All objects should be JSON-serializable (LangChain objects are pre-serialized)
return "json", self.dumps(obj).decode("utf-8")

def loads_typed(self, data: tuple[str, Union[str, bytes]]) -> Any:
Expand Down
196 changes: 196 additions & 0 deletions test_jsonplus_redis_serializer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
"""Standalone test to verify the JsonPlusRedisSerializer fix works.

This can be run directly without pytest infrastructure:
python test_fix_standalone.py
"""

from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langgraph.checkpoint.redis.jsonplus_redis import JsonPlusRedisSerializer


def test_human_message_serialization():
"""Test that HumanMessage can be serialized without TypeError."""
print("Testing HumanMessage serialization...")

serializer = JsonPlusRedisSerializer()
msg = HumanMessage(content="What is the weather?", id="msg-1")

try:
# This would raise TypeError before the fix
serialized = serializer.dumps(msg)
print(f" ✓ Serialized to {len(serialized)} bytes")

# Deserialize
deserialized = serializer.loads(serialized)
assert isinstance(deserialized, HumanMessage)
assert deserialized.content == "What is the weather?"
assert deserialized.id == "msg-1"
print(f" ✓ Deserialized correctly: {deserialized.content}")

return True
except TypeError as e:
print(f" ✗ FAILED: {e}")
return False


def test_all_message_types():
"""Test all LangChain message types."""
print("\nTesting all message types...")

serializer = JsonPlusRedisSerializer()
messages = [
HumanMessage(content="Hello"),
AIMessage(content="Hi!"),
SystemMessage(content="System prompt"),
]

for msg in messages:
try:
serialized = serializer.dumps(msg)
deserialized = serializer.loads(serialized)
assert type(deserialized) == type(msg)
print(f" ✓ {type(msg).__name__} works")
except Exception as e:
print(f" ✗ {type(msg).__name__} FAILED: {e}")
return False

return True


def test_message_list():
"""Test list of messages (common pattern in LangGraph)."""
print("\nTesting message list...")

serializer = JsonPlusRedisSerializer()
messages = [
HumanMessage(content="Question 1"),
AIMessage(content="Answer 1"),
HumanMessage(content="Question 2"),
]

try:
serialized = serializer.dumps(messages)
deserialized = serializer.loads(serialized)

assert isinstance(deserialized, list)
assert len(deserialized) == 3
assert all(isinstance(m, (HumanMessage, AIMessage)) for m in deserialized)
print(f" ✓ List of {len(deserialized)} messages works")

return True
except Exception as e:
print(f" ✗ FAILED: {e}")
return False


def test_nested_structure():
"""Test nested structure with messages (realistic LangGraph state)."""
print("\nTesting nested structure with messages...")

serializer = JsonPlusRedisSerializer()
state = {
"messages": [
HumanMessage(content="Query"),
AIMessage(content="Response"),
],
"step": 1,
}

try:
serialized = serializer.dumps(state)
deserialized = serializer.loads(serialized)

assert "messages" in deserialized
assert len(deserialized["messages"]) == 2
assert isinstance(deserialized["messages"][0], HumanMessage)
assert isinstance(deserialized["messages"][1], AIMessage)
print(f" ✓ Nested structure works")

return True
except Exception as e:
print(f" ✗ FAILED: {e}")
return False


def test_dumps_typed():
"""Test dumps_typed (what checkpointer actually uses)."""
print("\nTesting dumps_typed...")

serializer = JsonPlusRedisSerializer()
msg = HumanMessage(content="Test", id="test-123")

try:
type_str, blob = serializer.dumps_typed(msg)
assert type_str == "json"
assert isinstance(blob, str)
print(f" ✓ dumps_typed returns: type='{type_str}', blob={len(blob)} chars")

deserialized = serializer.loads_typed((type_str, blob))
assert isinstance(deserialized, HumanMessage)
assert deserialized.content == "Test"
print(f" ✓ loads_typed works correctly")

return True
except Exception as e:
print(f" ✗ FAILED: {e}")
return False


def test_backwards_compatibility():
"""Test that regular objects still work."""
print("\nTesting backwards compatibility...")

serializer = JsonPlusRedisSerializer()
test_cases = [
("string", "hello"),
("int", 42),
("dict", {"key": "value"}),
("list", [1, 2, 3]),
]

for name, obj in test_cases:
try:
serialized = serializer.dumps(obj)
deserialized = serializer.loads(serialized)
assert deserialized == obj
print(f" ✓ {name} works")
except Exception as e:
print(f" ✗ {name} FAILED: {e}")
return False

return True


def main():
"""Run all tests."""
print("=" * 70)
print("JsonPlusRedisSerializer Fix Validation")
print("=" * 70)

tests = [
test_human_message_serialization,
test_all_message_types,
test_message_list,
test_nested_structure,
test_dumps_typed,
test_backwards_compatibility,
]

results = []
for test in tests:
results.append(test())

print("\n" + "=" * 70)
print(f"Results: {sum(results)}/{len(results)} tests passed")
print("=" * 70)

if all(results):
print("\n✅ ALL TESTS PASSED - Fix is working correctly!")
return 0
else:
print("\n❌ SOME TESTS FAILED - Fix may not be working")
return 1


if __name__ == "__main__":
exit(main())
Loading