diff --git a/.gitignore b/.gitignore index a882340..dcbf3dc 100644 --- a/.gitignore +++ b/.gitignore @@ -235,3 +235,5 @@ libs/redis/docs/.Trash* .claude TASK_MEMORY.md *.code-workspace + +augment*.md diff --git a/agent_memory_server/cli.py b/agent_memory_server/cli.py index 335fb4d..32a27f3 100644 --- a/agent_memory_server/cli.py +++ b/agent_memory_server/cli.py @@ -94,6 +94,175 @@ async def run_migrations(): click.echo("Memory migrations completed successfully.") +@cli.command() +@click.option( + "--batch-size", + default=1000, + help="Number of keys to process in each batch", +) +@click.option( + "--dry-run", + is_flag=True, + help="Only count keys without migrating", +) +def migrate_working_memory(batch_size: int, dry_run: bool): + """ + Migrate working memory keys from string format to JSON format. + + This command migrates all working memory keys stored in the old string + format (JSON serialized as a string) to the new native Redis JSON format. + + Use --dry-run to see how many keys need migration without making changes. + """ + import asyncio + import time + + from agent_memory_server.utils.keys import Keys + from agent_memory_server.working_memory import ( + set_migration_complete, + ) + + configure_logging() + + async def run_migration(): + import json as json_module + + redis = await get_redis_conn() + + # Scan for string keys only using _type filter (much faster) + string_keys = [] + cursor = 0 + pattern = Keys.working_memory_key("*") + + click.echo("Scanning for working memory keys (string type only)...") + scan_start = time.time() + + while True: + # Use _type="string" to only get string keys directly + cursor, keys = await redis.scan( + cursor, match=pattern, count=batch_size, _type="string" + ) + + if keys: + string_keys.extend(keys) + + if cursor == 0: + break + + scan_time = time.time() - scan_start + + click.echo(f"Scan completed in {scan_time:.2f}s") + click.echo(f" String format (need migration): {len(string_keys)}") + + if not string_keys: + click.echo("\nNo keys need migration. All done!") + # Mark migration as complete + await set_migration_complete(redis) + return + + if dry_run: + click.echo("\n--dry-run specified, no changes made.") + return + + # Migrate keys in batches using pipeline + click.echo(f"\nMigrating {len(string_keys)} keys...") + migrate_start = time.time() + migrated = 0 + errors = 0 + + # Process in batches + for batch_start in range(0, len(string_keys), batch_size): + batch_keys = string_keys[batch_start : batch_start + batch_size] + + # Read all string data and TTLs in a pipeline + read_pipe = redis.pipeline() + for key in batch_keys: + read_pipe.get(key) + read_pipe.ttl(key) + results = await read_pipe.execute() + + # Parse results (alternating: data, ttl, data, ttl, ...) + migrations = [] # List of (key, data, ttl) tuples + for i, key in enumerate(batch_keys): + string_data = results[i * 2] + ttl = results[i * 2 + 1] + + if string_data is None: + continue + + try: + if isinstance(string_data, bytes): + string_data = string_data.decode("utf-8") + data = json_module.loads(string_data) + migrations.append((key, data, ttl)) + except Exception as e: + errors += 1 + logger.error(f"Failed to parse key {key}: {e}") + + # Execute migrations in a pipeline (delete + json.set + expire if needed) + if migrations: + write_pipe = redis.pipeline() + for key, data, ttl in migrations: + write_pipe.delete(key) + write_pipe.json().set(key, "$", data) + if ttl > 0: + write_pipe.expire(key, ttl) + + try: + await write_pipe.execute() + migrated += len(migrations) + except Exception as e: + # If batch fails, try one by one + logger.warning( + f"Batch migration failed, retrying individually: {e}" + ) + for key, data, ttl in migrations: + try: + await redis.delete(key) + await redis.json().set(key, "$", data) + if ttl > 0: + await redis.expire(key, ttl) + migrated += 1 + except Exception as e2: + errors += 1 + logger.error(f"Failed to migrate key {key}: {e2}") + + # Progress update + total_processed = batch_start + len(batch_keys) + if total_processed % 10000 == 0 or total_processed == len(string_keys): + elapsed = time.time() - migrate_start + rate = migrated / elapsed if elapsed > 0 else 0 + remaining = len(string_keys) - total_processed + eta = remaining / rate if rate > 0 else 0 + click.echo( + f" Migrated {migrated}/{len(string_keys)} " + f"({rate:.0f} keys/sec, ETA: {eta:.0f}s)" + ) + + migrate_time = time.time() - migrate_start + rate = migrated / migrate_time if migrate_time > 0 else 0 + + click.echo(f"\nMigration completed in {migrate_time:.2f}s") + click.echo(f" Migrated: {migrated}") + click.echo(f" Errors: {errors}") + click.echo(f" Rate: {rate:.0f} keys/sec") + + if errors == 0: + # Mark migration as complete + await set_migration_complete(redis) + click.echo("\nMigration status set to complete.") + click.echo( + "\nšŸ’” Tip: Set WORKING_MEMORY_MIGRATION_COMPLETE=true to skip " + "startup checks permanently." + ) + else: + click.echo( + "\nMigration completed with errors. " "Run again to retry failed keys." + ) + + asyncio.run(run_migration()) + + @cli.command() @click.option("--port", default=settings.port, help="Port to run the server on") @click.option("--host", default="0.0.0.0", help="Host to run the server on") diff --git a/agent_memory_server/config.py b/agent_memory_server/config.py index d1e065c..61535f9 100644 --- a/agent_memory_server/config.py +++ b/agent_memory_server/config.py @@ -259,6 +259,13 @@ class Settings(BaseSettings): 0.7 # Fraction of context window that triggers summarization ) + # Working memory migration settings + # Set to True to skip backward compatibility checks for old string-format keys. + # Use this after running 'agent-memory migrate-working-memory' or for fresh installs. + # When True, the server assumes all working memory keys are in JSON format, + # skipping the startup scan and per-read type checks for better performance. + working_memory_migration_complete: bool = False + # Query optimization settings query_optimization_prompt_template: str = """Transform this natural language query into an optimized version for semantic search. The goal is to make it more effective for finding semantically similar content while preserving the original intent. diff --git a/agent_memory_server/main.py b/agent_memory_server/main.py index 7eaf93c..363f7cb 100644 --- a/agent_memory_server/main.py +++ b/agent_memory_server/main.py @@ -74,9 +74,13 @@ async def lifespan(app: FastAPI): "Long-term memory requires OpenAI for embeddings, but OpenAI API key is not set" ) - # Set up Redis connection if long-term memory is enabled - if settings.long_term_memory: - await get_redis_conn() + # Set up Redis connection and check working memory migration status + redis_conn = await get_redis_conn() + + # Check if any working memory keys need migration from string to JSON format + from agent_memory_server.working_memory import check_and_set_migration_status + + await check_and_set_migration_status(redis_conn) # Initialize Docket for background tasks if enabled if settings.use_docket: diff --git a/agent_memory_server/working_memory.py b/agent_memory_server/working_memory.py index c75b637..886bd79 100644 --- a/agent_memory_server/working_memory.py +++ b/agent_memory_server/working_memory.py @@ -7,6 +7,7 @@ from redis.asyncio import Redis +from agent_memory_server.config import settings from agent_memory_server.models import ( MemoryMessage, MemoryRecord, @@ -19,12 +20,208 @@ logger = logging.getLogger(__name__) +# Redis keys for migration status (shared across workers, persists across restarts) +MIGRATION_STATUS_KEY = "working_memory:migration:complete" +MIGRATION_REMAINING_KEY = "working_memory:migration:remaining" -def json_datetime_handler(obj): - """JSON serializer for datetime objects.""" - if isinstance(obj, datetime): - return obj.isoformat() - raise TypeError(f"Object of type {type(obj)} is not JSON serializable") + +async def check_and_set_migration_status(redis_client: Redis | None = None) -> bool: + """ + Check if any working memory keys are still in old string format. + Stores migration status in Redis for cross-worker consistency. + + If WORKING_MEMORY_MIGRATION_COMPLETE=true is set, skips the scan entirely + and assumes all keys are in JSON format. + + Args: + redis_client: Optional Redis client + + Returns: + True if all keys are migrated (or no keys exist), False if string keys remain + """ + # If env variable is set, skip the scan entirely + if settings.working_memory_migration_complete: + logger.info( + "WORKING_MEMORY_MIGRATION_COMPLETE=true, skipping backward compatibility checks." + ) + return True + + if not redis_client: + redis_client = await get_redis_conn() + + # Check if migration status is already stored in Redis + status = await redis_client.get(MIGRATION_STATUS_KEY) + if status: + if isinstance(status, bytes): + status = status.decode("utf-8") + if status == "true": + logger.info( + "Migration status in Redis indicates complete. Skipping type checks." + ) + return True + + # Scan for working_memory:* keys of type STRING only + # This is much faster than scanning all keys and calling TYPE on each + cursor = 0 + string_keys_found = 0 + + try: + while True: + # Use _type="string" to only get string keys directly + cursor, keys = await redis_client.scan( + cursor=cursor, match="working_memory:*", count=1000, _type="string" + ) + + if keys: + # Filter out migration status keys (they're also strings) + keys = [ + k + for k in keys + if (k.decode("utf-8") if isinstance(k, bytes) else k) + not in (MIGRATION_STATUS_KEY, MIGRATION_REMAINING_KEY) + ] + string_keys_found += len(keys) + + if cursor == 0: + break + + if string_keys_found > 0: + # Store the count in Redis for atomic decrement during lazy migration + await redis_client.set(MIGRATION_REMAINING_KEY, str(string_keys_found)) + logger.info( + f"Found {string_keys_found} working memory keys in old string format. " + "Lazy migration enabled." + ) + return False + + # No string keys found - mark as complete in Redis + await redis_client.set(MIGRATION_STATUS_KEY, "true") + await redis_client.delete(MIGRATION_REMAINING_KEY) + + logger.info( + "No working memory string keys found. Skipping backward compatibility checks." + ) + return True + except Exception as e: + logger.error(f"Failed to check migration status: {e}") + return False + + +async def _decrement_remaining_count(redis_client: Redis) -> None: + """ + Atomically decrement the remaining string key counter. + When it reaches 0, mark migration as complete. + """ + try: + remaining = await redis_client.decr(MIGRATION_REMAINING_KEY) + if remaining <= 0: + await redis_client.set(MIGRATION_STATUS_KEY, "true") + await redis_client.delete(MIGRATION_REMAINING_KEY) + logger.info("All working memory keys have been migrated to JSON format.") + except Exception as e: + # Non-fatal - migration still works, just won't auto-complete + logger.warning(f"Failed to decrement migration counter: {e}") + + +async def is_migration_complete(redis_client: Redis | None = None) -> bool: + """Check if migration is complete.""" + if settings.working_memory_migration_complete: + return True + + if not redis_client: + redis_client = await get_redis_conn() + + status = await redis_client.get(MIGRATION_STATUS_KEY) + if status: + if isinstance(status, bytes): + status = status.decode("utf-8") + return status == "true" + return False + + +async def get_remaining_string_keys(redis_client: Redis | None = None) -> int: + """Get the count of remaining string keys (for testing/monitoring).""" + if not redis_client: + redis_client = await get_redis_conn() + + remaining = await redis_client.get(MIGRATION_REMAINING_KEY) + if remaining: + if isinstance(remaining, bytes): + remaining = remaining.decode("utf-8") + return int(remaining) + return 0 + + +async def reset_migration_status(redis_client: Redis | None = None) -> None: + """Reset migration status (for testing purposes).""" + if not redis_client: + redis_client = await get_redis_conn() + + await redis_client.delete(MIGRATION_STATUS_KEY, MIGRATION_REMAINING_KEY) + + +async def set_migration_complete(redis_client: Redis | None = None) -> None: + """Mark migration as complete (called by migration script).""" + if not redis_client: + redis_client = await get_redis_conn() + + await redis_client.set(MIGRATION_STATUS_KEY, "true") + await redis_client.delete(MIGRATION_REMAINING_KEY) + logger.info("Working memory migration marked as complete.") + + +async def _migrate_string_to_json( + redis_client: Redis, + key: str, + string_data: str, +) -> dict: + """ + Migrate working memory from old string format to new JSON format. + + Args: + redis_client: Redis client + key: The Redis key + string_data: The JSON string data from the old format + + Returns: + The parsed dict data + """ + try: + data = json.loads(string_data) + logger.info(f"Migrating working memory key {key} from string to JSON format") + + # Atomically migrate the key from string to JSON using a Lua script + # The script: get TTL, get value, delete, set as JSON, restore TTL if > 0 + lua_script = """ + local key = KEYS[1] + if redis.call('TYPE', key).ok == 'string' then + local ttl = redis.call('TTL', key) + local val = redis.call('GET', key) + redis.call('DEL', key) + redis.call('JSON.SET', key, '$', ARGV[1]) + if ttl > 0 then + redis.call('EXPIRE', key, ttl) + end + return val + else + return nil + end + """ + # Pass the JSON string as ARGV[1] + await redis_client.eval(lua_script, 1, key, json.dumps(data)) + + logger.info(f"Successfully migrated working memory key {key} to JSON format") + + # Atomically decrement the remaining counter + await _decrement_remaining_count(redis_client) + + return data + except json.JSONDecodeError as e: + logger.error(f"Failed to parse string data for key {key}: {e}") + raise + except Exception as e: + logger.error(f"Failed to migrate working memory key {key}: {e}") + raise async def list_sessions( @@ -103,8 +300,35 @@ async def get_working_memory( ) try: - data = await redis_client.get(key) - if not data: + working_memory_data = None + + # Check migration status (uses Redis, shared across workers) + migration_complete = await is_migration_complete(redis_client) + + if migration_complete: + # Fast path: all keys are already in JSON format + working_memory_data = await redis_client.json().get(key) + else: + # Slow path: check key type to determine storage format + key_type = await redis_client.type(key) + if isinstance(key_type, bytes): + key_type = key_type.decode("utf-8") + + if key_type == "ReJSON-RL": + # New JSON format + working_memory_data = await redis_client.json().get(key) + elif key_type == "string": + # Old string format - migrate to JSON + string_data = await redis_client.get(key) + if string_data: + if isinstance(string_data, bytes): + string_data = string_data.decode("utf-8") + working_memory_data = await _migrate_string_to_json( + redis_client, key, string_data + ) + # If key_type is "none", the key doesn't exist - working_memory_data stays None + + if not working_memory_data: logger.debug( f"No working memory found for parameters: {session_id}, {user_id}, {namespace}" ) @@ -125,9 +349,6 @@ async def get_working_memory( return None - # Parse the JSON data - working_memory_data = json.loads(data) - # Convert memory records back to MemoryRecord objects memories = [] for memory_data in working_memory_data.get("memories", []): @@ -233,21 +454,16 @@ async def set_working_memory( } try: + # Use Redis native JSON storage + await redis_client.json().set(key, "$", data) + if working_memory.ttl_seconds is not None: - # Store with TTL - await redis_client.setex( - key, - working_memory.ttl_seconds, - json.dumps(data, default=json_datetime_handler), - ) + # Set TTL separately for JSON keys + await redis_client.expire(key, working_memory.ttl_seconds) logger.info( f"Set working memory for session {working_memory.session_id} with TTL {working_memory.ttl_seconds}s" ) else: - await redis_client.set( - key, - json.dumps(data, default=json_datetime_handler), - ) logger.info( f"Set working memory for session {working_memory.session_id} with no TTL" ) diff --git a/tests/benchmarks/test_migration_benchmark.py b/tests/benchmarks/test_migration_benchmark.py new file mode 100644 index 0000000..deb960b --- /dev/null +++ b/tests/benchmarks/test_migration_benchmark.py @@ -0,0 +1,378 @@ +""" +Benchmark tests for working memory migration from string to JSON format. + +Run with: + uv run pytest tests/benchmarks/test_migration_benchmark.py -v -s --benchmark + +Use environment variables to control scale: + BENCHMARK_KEY_COUNT=10000 uv run pytest tests/benchmarks/test_migration_benchmark.py -v -s +""" + +import json +import os +import time + +import pytest + +from agent_memory_server.utils.keys import Keys +from agent_memory_server.working_memory import ( + check_and_set_migration_status, + get_working_memory, + reset_migration_status, +) + + +# Default to 1000 keys for CI, can override with env var +DEFAULT_KEY_COUNT = 1000 +KEY_COUNT = int(os.environ.get("BENCHMARK_KEY_COUNT", DEFAULT_KEY_COUNT)) + + +def create_old_format_data(session_id: str, namespace: str) -> dict: + """Create old-format working memory data.""" + return { + "messages": [ + { + "id": f"msg-{session_id}", + "role": "user", + "content": f"Hello from session {session_id}", + "created_at": "2024-01-01T00:00:00+00:00", + } + ], + "memories": [], + "session_id": session_id, + "namespace": namespace, + "context": None, + "user_id": None, + "tokens": 10, + "ttl_seconds": None, + "data": {}, + "long_term_memory_strategy": {"strategy": "discrete"}, + "last_accessed": 1704067200, + "created_at": 1704067200, + "updated_at": 1704067200, + } + + +@pytest.fixture +async def cleanup_working_memory_keys(async_redis_client): + """Clean up all working memory keys before and after test.""" + + async def cleanup(): + cursor = 0 + deleted = 0 + while True: + cursor, keys = await async_redis_client.scan( + cursor=cursor, match="working_memory:*", count=1000 + ) + if keys: + await async_redis_client.delete(*keys) + deleted += len(keys) + if cursor == 0: + break + return deleted + + # Clean before + await cleanup() + await reset_migration_status(async_redis_client) + + yield + + # Clean after + await cleanup() + await reset_migration_status(async_redis_client) + + +@pytest.mark.benchmark +class TestMigrationBenchmark: + """Benchmark tests for migration performance.""" + + @pytest.mark.asyncio + async def test_startup_scan_performance( + self, async_redis_client, cleanup_working_memory_keys + ): + """Benchmark: How long does startup scan take with many string keys? + + With early-exit optimization, this should be very fast - it stops + as soon as it finds the first string key. + """ + namespace = "benchmark" + + # Create string keys in batches using pipeline + print(f"\nšŸ“Š Creating {KEY_COUNT:,} string keys...") + start = time.perf_counter() + + batch_size = 1000 + for batch_start in range(0, KEY_COUNT, batch_size): + pipe = async_redis_client.pipeline() + for i in range(batch_start, min(batch_start + batch_size, KEY_COUNT)): + key = Keys.working_memory_key( + session_id=f"bench-session-{i}", namespace=namespace + ) + data = create_old_format_data(f"bench-session-{i}", namespace) + pipe.set(key, json.dumps(data)) + await pipe.execute() + + if (batch_start + batch_size) % 10000 == 0: + print(f" Created {batch_start + batch_size:,} keys...") + + creation_time = time.perf_counter() - start + print(f"āœ… Created {KEY_COUNT:,} keys in {creation_time:.2f}s") + + # Benchmark startup scan (with early exit) + print("\nšŸ“Š Benchmarking startup scan (early exit on first string key)...") + await reset_migration_status(async_redis_client) + + start = time.perf_counter() + result = await check_and_set_migration_status(async_redis_client) + scan_time = time.perf_counter() - start + + print(f"āœ… Startup scan completed in {scan_time:.4f}s (early exit)") + print(f" Result: migration_complete={result}") + + assert result is False # Should find string keys + + @pytest.mark.asyncio + async def test_lazy_migration_performance( + self, async_redis_client, cleanup_working_memory_keys + ): + """Benchmark: How long does lazy migration take per key? + + This tests the CURRENT implementation which re-scans after each migration. + With N keys, this is O(N²) - very slow for large N. + """ + namespace = "benchmark" + # Use smaller count for this test since it's O(N²) + key_count = min(KEY_COUNT, 100) + + # Create string keys + print(f"\nšŸ“Š Creating {key_count} string keys for lazy migration test...") + for i in range(key_count): + key = Keys.working_memory_key( + session_id=f"lazy-session-{i}", namespace=namespace + ) + data = create_old_format_data(f"lazy-session-{i}", namespace) + await async_redis_client.set(key, json.dumps(data)) + + # Set migration status + await reset_migration_status(async_redis_client) + await check_and_set_migration_status(async_redis_client) + + # Benchmark lazy migration (read each key, triggering migration) + print(f"\nšŸ“Š Benchmarking lazy migration of {key_count} keys...") + start = time.perf_counter() + + for i in range(key_count): + await get_working_memory( + session_id=f"lazy-session-{i}", + namespace=namespace, + redis_client=async_redis_client, + ) + + migration_time = time.perf_counter() - start + print(f"āœ… Lazy migration completed in {migration_time:.2f}s") + print(f" Average per key: {migration_time / key_count * 1000:.2f}ms") + print(f" Keys/second: {key_count / migration_time:,.0f}") + + @pytest.mark.asyncio + async def test_post_migration_read_performance( + self, async_redis_client, cleanup_working_memory_keys + ): + """Benchmark: Read performance after migration is complete (fast path).""" + namespace = "benchmark" + key_count = min(KEY_COUNT, 1000) + + # Create JSON keys directly (simulating post-migration state) + print(f"\nšŸ“Š Creating {key_count} JSON keys...") + batch_size = 100 + for batch_start in range(0, key_count, batch_size): + pipe = async_redis_client.pipeline() + for i in range(batch_start, min(batch_start + batch_size, key_count)): + key = Keys.working_memory_key( + session_id=f"json-session-{i}", namespace=namespace + ) + data = create_old_format_data(f"json-session-{i}", namespace) + pipe.json().set(key, "$", data) + await pipe.execute() + + # Set migration as complete + await reset_migration_status(async_redis_client) + await check_and_set_migration_status(async_redis_client) + + # Benchmark reads (should use fast path) + print(f"\nšŸ“Š Benchmarking fast-path reads of {key_count} keys...") + start = time.perf_counter() + + for i in range(key_count): + await get_working_memory( + session_id=f"json-session-{i}", + namespace=namespace, + redis_client=async_redis_client, + ) + + read_time = time.perf_counter() - start + print(f"āœ… Fast-path reads completed in {read_time:.2f}s") + print(f" Average per key: {read_time / key_count * 1000:.2f}ms") + print(f" Keys/second: {key_count / read_time:,.0f}") + + @pytest.mark.asyncio + async def test_worst_case_single_string_key_at_end( + self, async_redis_client, cleanup_working_memory_keys + ): + """Benchmark: Worst case - 1M JSON keys with 1 string key (scanned last). + + This tests the scenario where early-exit doesn't help because + the string key is found at the very end of the scan. + """ + namespace = "benchmark" + + # Create JSON keys in batches using pipeline + print(f"\nšŸ“Š Creating {KEY_COUNT:,} JSON keys + 1 string key...") + start = time.perf_counter() + + batch_size = 1000 + for batch_start in range(0, KEY_COUNT, batch_size): + pipe = async_redis_client.pipeline() + for i in range(batch_start, min(batch_start + batch_size, KEY_COUNT)): + key = Keys.working_memory_key( + session_id=f"json-session-{i}", namespace=namespace + ) + data = create_old_format_data(f"json-session-{i}", namespace) + pipe.json().set(key, "$", data) + await pipe.execute() + + if (batch_start + batch_size) % 100000 == 0: + print(f" Created {batch_start + batch_size:,} JSON keys...") + + # Add ONE string key with a session ID that sorts last alphabetically + # Using 'zzz' prefix to make it likely to be scanned last + string_key = Keys.working_memory_key( + session_id="zzz-string-key-last", namespace=namespace + ) + string_data = create_old_format_data("zzz-string-key-last", namespace) + await async_redis_client.set(string_key, json.dumps(string_data)) + + creation_time = time.perf_counter() - start + print( + f"āœ… Created {KEY_COUNT:,} JSON keys + 1 string key in {creation_time:.2f}s" + ) + + # Benchmark startup scan - must scan all keys to find the string one + print("\nšŸ“Š Benchmarking startup scan (worst case - string key at end)...") + await reset_migration_status(async_redis_client) + + start = time.perf_counter() + result = await check_and_set_migration_status(async_redis_client) + scan_time = time.perf_counter() - start + + print(f"āœ… Startup scan completed in {scan_time:.2f}s") + print(f" Result: migration_complete={result}") + print(f" Keys scanned per second: {KEY_COUNT / scan_time:,.0f}") + + # Should find the string key + assert result is False + + @pytest.mark.asyncio + async def test_migration_script_performance( + self, async_redis_client, cleanup_working_memory_keys + ): + """Benchmark: Migration script performance with pipelined operations.""" + namespace = "benchmark" + + # Create string keys in batches using pipeline + print(f"\nšŸ“Š Creating {KEY_COUNT:,} string keys for migration...") + start = time.perf_counter() + + batch_size = 1000 + for batch_start in range(0, KEY_COUNT, batch_size): + pipe = async_redis_client.pipeline() + for i in range(batch_start, min(batch_start + batch_size, KEY_COUNT)): + key = Keys.working_memory_key( + session_id=f"string-session-{i}", namespace=namespace + ) + data = create_old_format_data(f"string-session-{i}", namespace) + pipe.set(key, json.dumps(data)) + await pipe.execute() + + if (batch_start + batch_size) % 100000 == 0: + print(f" Created {batch_start + batch_size:,} string keys...") + + creation_time = time.perf_counter() - start + print(f"āœ… Created {KEY_COUNT:,} string keys in {creation_time:.2f}s") + + # Benchmark migration (simulating what the CLI does) + print("\nšŸ“Š Benchmarking pipelined migration...") + migrate_start = time.perf_counter() + + # Scan and collect string keys + string_keys = [] + cursor = 0 + while True: + cursor, keys = await async_redis_client.scan( + cursor, match="working_memory:*", count=1000 + ) + if keys: + pipe = async_redis_client.pipeline() + for key in keys: + pipe.type(key) + types = await pipe.execute() + + for key, key_type in zip(keys, types, strict=False): + if isinstance(key_type, bytes): + key_type = key_type.decode("utf-8") + if key_type == "string": + string_keys.append(key) + + if cursor == 0: + break + + scan_time = time.perf_counter() - migrate_start + print( + f" Scan completed in {scan_time:.2f}s ({len(string_keys):,} string keys)" + ) + + # Migrate in batches + migrated = 0 + for batch_start in range(0, len(string_keys), batch_size): + batch_keys = string_keys[batch_start : batch_start + batch_size] + + # Read all string data + read_pipe = async_redis_client.pipeline() + for key in batch_keys: + read_pipe.get(key) + string_data_list = await read_pipe.execute() + + # Parse and migrate + write_pipe = async_redis_client.pipeline() + for key, string_data in zip(batch_keys, string_data_list, strict=False): + if string_data is None: + continue + if isinstance(string_data, bytes): + string_data = string_data.decode("utf-8") + data = json.loads(string_data) + write_pipe.delete(key) + write_pipe.json().set(key, "$", data) + + await write_pipe.execute() + migrated += len(batch_keys) + + if migrated % 100000 == 0: + elapsed = time.perf_counter() - migrate_start + print( + f" Migrated {migrated:,} keys ({migrated / elapsed:,.0f} keys/sec)" + ) + + migrate_time = time.perf_counter() - migrate_start + rate = migrated / migrate_time + + print(f"āœ… Migration completed in {migrate_time:.2f}s") + print(f" Migrated: {migrated:,}") + print(f" Rate: {rate:,.0f} keys/sec") + + # Verify migration + sample_key = Keys.working_memory_key( + session_id="string-session-0", namespace=namespace + ) + key_type = await async_redis_client.type(sample_key) + if isinstance(key_type, bytes): + key_type = key_type.decode("utf-8") + assert key_type == "ReJSON-RL", f"Expected ReJSON-RL, got {key_type}" diff --git a/tests/test_working_memory.py b/tests/test_working_memory.py index f7db305..b7f6884 100644 --- a/tests/test_working_memory.py +++ b/tests/test_working_memory.py @@ -339,3 +339,368 @@ async def test_working_memory_ttl_update_preserves_ttl(self, async_redis_client) ) ttl = await async_redis_client.ttl(key) assert 0 < ttl <= ttl_seconds + + @pytest.mark.asyncio + async def test_backward_compatibility_string_to_json_migration( + self, async_redis_client + ): + """Test that old string-format working memory is migrated to JSON format on read.""" + import json + + from agent_memory_server.working_memory import ( + check_and_set_migration_status, + is_migration_complete, + reset_migration_status, + ) + + # Reset migration status to ensure lazy migration is active + await reset_migration_status(async_redis_client) + assert not await is_migration_complete(async_redis_client) + + session_id = "test-migration-session" + namespace = "test-namespace" + + # Create old-format data (stringified JSON) + old_format_data = { + "messages": [ + { + "id": "msg-1", + "role": "user", + "content": "Hello", + "created_at": "2024-01-01T00:00:00+00:00", + } + ], + "memories": [ + { + "id": "mem-1", + "text": "User prefers dark mode", + "memory_type": "semantic", + } + ], + "context": None, + "user_id": "user123", + "tokens": 10, + "session_id": session_id, + "namespace": namespace, + "ttl_seconds": None, + "data": {}, + "long_term_memory_strategy": {"strategy": "discrete"}, + "last_accessed": 1704067200, + "created_at": 1704067200, + "updated_at": 1704067200, + } + + # Store as old string format directly + key = Keys.working_memory_key(session_id=session_id, namespace=namespace) + await async_redis_client.set(key, json.dumps(old_format_data)) + + # Verify it's stored as string (not JSON) + key_type = await async_redis_client.type(key) + # Redis returns bytes, decode if needed + if isinstance(key_type, bytes): + key_type = key_type.decode("utf-8") + assert key_type == "string" + + # Run check to set up the counter (finds 1 string key) + await check_and_set_migration_status(async_redis_client) + assert not await is_migration_complete(async_redis_client) + + # Now read using get_working_memory - should trigger migration + retrieved_mem = await get_working_memory( + session_id=session_id, + namespace=namespace, + redis_client=async_redis_client, + ) + + # Verify data was retrieved correctly + assert retrieved_mem is not None + assert retrieved_mem.session_id == session_id + assert retrieved_mem.namespace == namespace + assert len(retrieved_mem.messages) == 1 + assert retrieved_mem.messages[0].role == "user" + assert retrieved_mem.messages[0].content == "Hello" + assert len(retrieved_mem.memories) == 1 + assert retrieved_mem.memories[0].text == "User prefers dark mode" + + # Verify the key was migrated to JSON format + key_type_after = await async_redis_client.type(key) + if isinstance(key_type_after, bytes): + key_type_after = key_type_after.decode("utf-8") + assert key_type_after == "ReJSON-RL" + + # Migration auto-completes when counter reaches 0 (only 1 key, now migrated) + assert await is_migration_complete(async_redis_client) + + # Verify we can read it again (now from JSON format, using fast path) + retrieved_again = await get_working_memory( + session_id=session_id, + namespace=namespace, + redis_client=async_redis_client, + ) + assert retrieved_again is not None + assert retrieved_again.session_id == session_id + + @pytest.mark.asyncio + async def test_migration_preserves_ttl(self, async_redis_client): + """Test that TTL is preserved when migrating from string to JSON format.""" + import json + + from agent_memory_server.working_memory import reset_migration_status + + # Reset migration status to ensure lazy migration is active + await reset_migration_status(async_redis_client) + + session_id = "test-ttl-migration-session" + namespace = "test-namespace" + ttl_seconds = 3600 # 1 hour + + # Create old-format data with TTL + old_format_data = { + "messages": [], + "memories": [], + "session_id": session_id, + "namespace": namespace, + "context": None, + "user_id": None, + "tokens": 0, + "ttl_seconds": ttl_seconds, + "data": {}, + "long_term_memory_strategy": {"strategy": "discrete"}, + "last_accessed": 1704067200, + "created_at": 1704067200, + "updated_at": 1704067200, + } + + # Store as old string format with TTL + key = Keys.working_memory_key(session_id=session_id, namespace=namespace) + await async_redis_client.set(key, json.dumps(old_format_data), ex=ttl_seconds) + + # Verify TTL is set + ttl_before = await async_redis_client.ttl(key) + assert ttl_before > 0 + assert ttl_before <= ttl_seconds + + # Trigger migration by reading + retrieved_mem = await get_working_memory( + session_id=session_id, + namespace=namespace, + redis_client=async_redis_client, + ) + assert retrieved_mem is not None + + # Verify key was migrated to JSON + key_type = await async_redis_client.type(key) + if isinstance(key_type, bytes): + key_type = key_type.decode("utf-8") + assert key_type == "ReJSON-RL" + + # Verify TTL was preserved + ttl_after = await async_redis_client.ttl(key) + assert ttl_after > 0 + # TTL should be close to original (within a few seconds of test execution) + assert ttl_after <= ttl_seconds + assert ttl_after >= ttl_before - 5 # Allow 5 seconds for test execution + + @pytest.mark.asyncio + async def test_check_and_set_migration_status_with_no_keys( + self, async_redis_client + ): + """Test migration status check when no working memory keys exist.""" + from agent_memory_server.working_memory import ( + check_and_set_migration_status, + is_migration_complete, + reset_migration_status, + ) + + # Reset to ensure clean state + await reset_migration_status(async_redis_client) + assert not await is_migration_complete(async_redis_client) + + # Check status with no keys - should mark as migrated (nothing to migrate) + result = await check_and_set_migration_status(async_redis_client) + assert result is True + assert await is_migration_complete(async_redis_client) + + @pytest.mark.asyncio + async def test_check_and_set_migration_status_with_json_keys_only( + self, async_redis_client + ): + """Test migration status check when only JSON keys exist.""" + from agent_memory_server.working_memory import ( + check_and_set_migration_status, + is_migration_complete, + reset_migration_status, + ) + + # Reset to ensure clean state + await reset_migration_status(async_redis_client) + + # Create a JSON key + session_id = "test-json-session" + namespace = "test-namespace" + memories = [ + MemoryRecord( + text="Test memory", + id="mem-1", + memory_type=MemoryTypeEnum.SEMANTIC, + ), + ] + working_mem = WorkingMemory( + memories=memories, + session_id=session_id, + namespace=namespace, + ) + await set_working_memory(working_mem, redis_client=async_redis_client) + + # Check status - should mark as migrated (only JSON keys) + result = await check_and_set_migration_status(async_redis_client) + assert result is True + assert await is_migration_complete(async_redis_client) + + @pytest.mark.asyncio + async def test_check_and_set_migration_status_with_string_keys( + self, async_redis_client + ): + """Test migration status check when string keys exist.""" + import json + + from agent_memory_server.working_memory import ( + check_and_set_migration_status, + is_migration_complete, + reset_migration_status, + ) + + # Reset to ensure clean state + await reset_migration_status(async_redis_client) + + # Create an old-format string key + key = Keys.working_memory_key( + session_id="test-string-session", namespace="test-namespace" + ) + old_data = { + "messages": [], + "memories": [], + "session_id": "test-string-session", + "namespace": "test-namespace", + } + await async_redis_client.set(key, json.dumps(old_data)) + + # Check status - should NOT mark as migrated (string keys exist) + result = await check_and_set_migration_status(async_redis_client) + assert result is False + assert not await is_migration_complete(async_redis_client) + + @pytest.mark.asyncio + async def test_migration_status_set_by_set_migration_complete( + self, async_redis_client + ): + """Test that set_migration_complete() marks migration as done.""" + import json + + from agent_memory_server.working_memory import ( + check_and_set_migration_status, + get_remaining_string_keys, + is_migration_complete, + reset_migration_status, + ) + + # Reset to ensure clean state + await reset_migration_status(async_redis_client) + + # Clean up any existing working_memory keys from other tests + cursor = 0 + while True: + cursor, keys = await async_redis_client.scan( + cursor=cursor, match="working_memory:*", count=100 + ) + if keys: + await async_redis_client.delete(*keys) + if cursor == 0: + break + + # Create two old-format string keys + for i in range(2): + key = Keys.working_memory_key( + session_id=f"test-migrate-session-{i}", namespace="test-namespace" + ) + old_data = { + "messages": [], + "memories": [], + "session_id": f"test-migrate-session-{i}", + "namespace": "test-namespace", + "context": None, + "user_id": None, + "tokens": 0, + "ttl_seconds": None, + "data": {}, + "long_term_memory_strategy": {"strategy": "discrete"}, + "last_accessed": 1704067200, + "created_at": 1704067200, + "updated_at": 1704067200, + } + await async_redis_client.set(key, json.dumps(old_data)) + + # Check status - should NOT be migrated, counter should be 2 + await check_and_set_migration_status(async_redis_client) + assert not await is_migration_complete(async_redis_client) + assert await get_remaining_string_keys(async_redis_client) == 2 + + # Read first key - triggers migration, counter decrements to 1 + await get_working_memory( + session_id="test-migrate-session-0", + namespace="test-namespace", + redis_client=async_redis_client, + ) + assert not await is_migration_complete(async_redis_client) + assert await get_remaining_string_keys(async_redis_client) == 1 + + # Read second key - triggers migration, counter decrements to 0, auto-completes + await get_working_memory( + session_id="test-migrate-session-1", + namespace="test-namespace", + redis_client=async_redis_client, + ) + # Now migration should be complete (auto-completed when counter hit 0) + assert await is_migration_complete(async_redis_client) + assert await get_remaining_string_keys(async_redis_client) == 0 + + @pytest.mark.asyncio + async def test_migration_skipped_when_env_variable_set( + self, async_redis_client, monkeypatch + ): + """Test that migration check is skipped when WORKING_MEMORY_MIGRATION_COMPLETE=true.""" + import json + + from agent_memory_server import config + from agent_memory_server.working_memory import ( + check_and_set_migration_status, + is_migration_complete, + reset_migration_status, + ) + + # Reset to ensure clean state + await reset_migration_status(async_redis_client) + + # Create an old-format string key (would normally trigger lazy migration) + key = Keys.working_memory_key( + session_id="test-env-skip-session", namespace="test-namespace" + ) + old_data = { + "messages": [], + "memories": [], + "session_id": "test-env-skip-session", + "namespace": "test-namespace", + } + await async_redis_client.set(key, json.dumps(old_data)) + + # Set the env variable via settings + monkeypatch.setattr(config.settings, "working_memory_migration_complete", True) + + # Check status - should skip scan and mark as complete immediately + result = await check_and_set_migration_status(async_redis_client) + assert result is True + assert await is_migration_complete(async_redis_client) + + # Clean up + await async_redis_client.delete(key) + monkeypatch.setattr(config.settings, "working_memory_migration_complete", False) diff --git a/tests/test_working_memory_strategies.py b/tests/test_working_memory_strategies.py index f06ef5c..b5f48c2 100644 --- a/tests/test_working_memory_strategies.py +++ b/tests/test_working_memory_strategies.py @@ -1,7 +1,6 @@ """Tests for working memory strategy integration.""" -import json -from unittest.mock import AsyncMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -216,19 +215,22 @@ async def test_set_working_memory_with_strategy(self): with patch( "agent_memory_server.working_memory.get_redis_conn" ) as mock_get_redis: - mock_redis = AsyncMock() + mock_redis = MagicMock() + mock_redis.expire = AsyncMock() + # json() is synchronous but returns an object with async methods + mock_json = MagicMock() + mock_json.set = AsyncMock() + mock_redis.json.return_value = mock_json mock_get_redis.return_value = mock_redis await set_working_memory(memory, mock_redis) - # Verify Redis set was called - mock_redis.set.assert_called_once() - call_args = mock_redis.set.call_args - - # Parse the stored data to verify strategy was included - import json + # Verify Redis JSON set was called + mock_json.set.assert_called_once() + call_args = mock_json.set.call_args - stored_data = json.loads(call_args[0][1]) + # The data is passed directly as a dict (not JSON string) to redis.json().set() + stored_data = call_args[0][2] # Third positional arg is the data assert "long_term_memory_strategy" in stored_data assert stored_data["long_term_memory_strategy"]["strategy"] == "summary" assert ( @@ -259,11 +261,22 @@ async def test_get_working_memory_with_strategy(self): "updated_at": 1640995200, } - with patch( - "agent_memory_server.working_memory.get_redis_conn" - ) as mock_get_redis: + with ( + patch( + "agent_memory_server.working_memory.get_redis_conn" + ) as mock_get_redis, + patch( + "agent_memory_server.working_memory.is_migration_complete" + ) as mock_is_migration_complete, + ): mock_redis = AsyncMock() - mock_redis.get.return_value = json.dumps(stored_data).encode() + # Mock is_migration_complete to return True (fast path) + mock_is_migration_complete.return_value = True + # json() is synchronous but returns an object with async methods + mock_json = MagicMock() + # Redis JSON returns dict directly, not bytes + mock_json.get = AsyncMock(return_value=stored_data) + mock_redis.json = MagicMock(return_value=mock_json) mock_get_redis.return_value = mock_redis result = await get_working_memory( @@ -297,11 +310,22 @@ async def test_get_working_memory_without_strategy_uses_default(self): "updated_at": 1640995200, } - with patch( - "agent_memory_server.working_memory.get_redis_conn" - ) as mock_get_redis: + with ( + patch( + "agent_memory_server.working_memory.get_redis_conn" + ) as mock_get_redis, + patch( + "agent_memory_server.working_memory.is_migration_complete" + ) as mock_is_migration_complete, + ): mock_redis = AsyncMock() - mock_redis.get.return_value = json.dumps(stored_data).encode() + # Mock is_migration_complete to return True (fast path) + mock_is_migration_complete.return_value = True + # json() is synchronous but returns an object with async methods + mock_json = MagicMock() + # Redis JSON returns dict directly, not bytes + mock_json.get = AsyncMock(return_value=stored_data) + mock_redis.json = MagicMock(return_value=mock_json) mock_get_redis.return_value = mock_redis result = await get_working_memory(