Skip to content

Commit 63f5138

Browse files
authored
fix: retain async fails (#40)
* fix: retain async fails * fix: retain async fails
1 parent e468a4e commit 63f5138

File tree

3 files changed

+199
-13
lines changed

3 files changed

+199
-13
lines changed

hindsight-api/hindsight_api/api/http.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1452,18 +1452,22 @@ async def api_list_operations(bank_id: str):
14521452
bank_id,
14531453
)
14541454

1455+
def parse_metadata(metadata):
    """Normalize a ``result_metadata`` value into something with ``.get``.

    The value may arrive as ``None``, as a JSON-encoded string, or as an
    already-decoded dict.

    Args:
        metadata: The raw ``result_metadata`` value of an operation row.

    Returns:
        ``{}`` when ``metadata`` is ``None``, is a string that is not valid
        JSON, or is a JSON string that does not decode to an object.
        The decoded dict for a valid JSON-object string. Any non-string,
        non-``None`` value is returned unchanged.
    """
    if metadata is None:
        return {}
    if isinstance(metadata, str):
        try:
            decoded = json.loads(metadata)
        except ValueError:
            # json.JSONDecodeError subclasses ValueError. A single row with
            # empty or corrupt metadata should not fail the whole listing.
            return {}
        # Callers use ``.get``; JSON such as "null" or "[1, 2]" decodes to a
        # value without it, so fall back to an empty mapping.
        return decoded if isinstance(decoded, dict) else {}
    return metadata
1462+
14551463
return {
14561464
"bank_id": bank_id,
14571465
"operations": [
14581466
{
14591467
"id": str(row["operation_id"]),
14601468
"task_type": row["operation_type"],
1461-
"items_count": row["result_metadata"].get("items_count", 0)
1462-
if row["result_metadata"]
1463-
else 0,
1464-
"document_id": row["result_metadata"].get("document_id")
1465-
if row["result_metadata"]
1466-
else None,
1469+
"items_count": parse_metadata(row["result_metadata"]).get("items_count", 0),
1470+
"document_id": parse_metadata(row["result_metadata"]).get("document_id"),
14671471
"created_at": row["created_at"].isoformat(),
14681472
"status": row["status"],
14691473
"error_message": row["error_message"],
@@ -1499,7 +1503,7 @@ async def api_cancel_operation(bank_id: str, operation_id: str):
14991503
async with acquire_with_retry(pool) as conn:
15001504
# Check if operation exists and belongs to this memory bank
15011505
result = await conn.fetchrow(
1502-
"SELECT bank_id FROM async_operations WHERE id = $1 AND bank_id = $2", op_uuid, bank_id
1506+
"SELECT bank_id FROM async_operations WHERE operation_id = $1 AND bank_id = $2", op_uuid, bank_id
15031507
)
15041508

15051509
if not result:
@@ -1508,7 +1512,7 @@ async def api_cancel_operation(bank_id: str, operation_id: str):
15081512
)
15091513

15101514
# Delete the operation
1511-
await conn.execute("DELETE FROM async_operations WHERE id = $1", op_uuid)
1515+
await conn.execute("DELETE FROM async_operations WHERE operation_id = $1", op_uuid)
15121516

15131517
return {
15141518
"success": True,
@@ -1769,13 +1773,13 @@ async def api_retain(bank_id: str, request: RetainRequest):
17691773
async with acquire_with_retry(pool) as conn:
17701774
await conn.execute(
17711775
"""
1772-
INSERT INTO async_operations (id, bank_id, task_type, items_count)
1776+
INSERT INTO async_operations (operation_id, bank_id, operation_type, result_metadata)
17731777
VALUES ($1, $2, $3, $4)
17741778
""",
17751779
operation_id,
17761780
bank_id,
17771781
"retain",
1778-
len(contents),
1782+
json.dumps({"items_count": len(contents)}),
17791783
)
17801784

17811785
# Submit task to background queue

hindsight-api/hindsight_api/engine/memory_engine.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ async def execute_task(self, task_dict: dict[str, Any]):
311311
pool = await self._get_pool()
312312
async with acquire_with_retry(pool) as conn:
313313
result = await conn.fetchrow(
314-
"SELECT id FROM async_operations WHERE id = $1", uuid.UUID(operation_id)
314+
"SELECT operation_id FROM async_operations WHERE operation_id = $1", uuid.UUID(operation_id)
315315
)
316316
if not result:
317317
# Operation was cancelled, skip processing
@@ -369,7 +369,7 @@ async def _delete_operation_record(self, operation_id: str):
369369
try:
370370
pool = await self._get_pool()
371371
async with acquire_with_retry(pool) as conn:
372-
await conn.execute("DELETE FROM async_operations WHERE id = $1", uuid.UUID(operation_id))
372+
await conn.execute("DELETE FROM async_operations WHERE operation_id = $1", uuid.UUID(operation_id))
373373
except Exception as e:
374374
logger.error(f"Failed to delete async operation record {operation_id}: {e}")
375375

@@ -386,7 +386,7 @@ async def _mark_operation_failed(self, operation_id: str, error_message: str, er
386386
"""
387387
UPDATE async_operations
388388
SET status = 'failed', error_message = $2
389-
WHERE id = $1
389+
WHERE operation_id = $1
390390
""",
391391
uuid.UUID(operation_id),
392392
truncated_error,

hindsight-api/tests/test_http_api_integration.py

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -426,3 +426,185 @@ async def test_document_deletion(api_client):
426426
f"/v1/default/banks/{test_bank_id}/documents/sales-report-q1-2024"
427427
)
428428
assert response.status_code == 404
429+
430+
431+
@pytest.mark.asyncio
async def test_async_retain(api_client):
    """Test asynchronous retain functionality.

    When async=true is passed, the retain endpoint should:
    1. Return immediately with success and async_=true
    2. Process the content in the background
    3. Eventually store the memories
    """
    import asyncio

    test_bank_id = f"async_retain_test_{datetime.now().timestamp()}"

    # Store memory with async=true
    response = await api_client.post(
        f"/v1/default/banks/{test_bank_id}/memories",
        json={
            "async": True,
            "items": [
                {
                    "content": "Alice is a senior engineer at TechCorp. She has been working on the authentication system for 5 years.",
                    "context": "team introduction",
                }
            ],
        },
    )
    assert response.status_code == 200
    result = response.json()
    assert result["success"] is True
    assert result["async"] is True, "Response should indicate async processing"
    assert result["items_count"] == 1

    # Check operations endpoint to see the pending operation
    response = await api_client.get(f"/v1/default/banks/{test_bank_id}/operations")
    assert response.status_code == 200
    ops_result = response.json()
    assert "operations" in ops_result

    # Wait for async processing to complete (poll with timeout).
    # The deadline is measured on the event loop's monotonic clock so that
    # time spent inside the HTTP requests counts toward the timeout, not
    # just the sleep intervals.
    max_wait_seconds = 30
    poll_interval = 0.5
    loop = asyncio.get_running_loop()
    deadline = loop.time() + max_wait_seconds
    memories_found = False

    while True:
        # Check if memories are stored
        response = await api_client.get(
            f"/v1/default/banks/{test_bank_id}/memories/list",
            params={"limit": 10},
        )
        assert response.status_code == 200
        items = response.json()["items"]

        if len(items) > 0:
            memories_found = True
            break

        # Always poll at least once; give up only after the deadline passes.
        if loop.time() >= deadline:
            break

        await asyncio.sleep(poll_interval)

    assert memories_found, f"Async retain did not complete within {max_wait_seconds} seconds"

    # Verify we can recall the stored memory
    response = await api_client.post(
        f"/v1/default/banks/{test_bank_id}/memories/recall",
        json={
            "query": "Who works at TechCorp?",
            "thinking_budget": 30,
        },
    )
    assert response.status_code == 200
    search_results = response.json()
    assert len(search_results["results"]) > 0, "Should find the asynchronously stored memory"

    # Verify Alice is mentioned
    found_alice = any("Alice" in r["text"] for r in search_results["results"])
    assert found_alice, "Should find Alice in search results"
508+
509+
510+
@pytest.mark.asyncio
async def test_async_retain_parallel(api_client):
    """Test multiple async retain operations running in parallel.

    Verifies that:
    1. Multiple async operations can be submitted concurrently
    2. All operations complete successfully
    3. The exact number of documents are processed
    """
    import asyncio

    test_bank_id = f"async_parallel_test_{datetime.now().timestamp()}"
    num_documents = 5

    # Prepare multiple documents to retain
    documents = [
        {
            "content": f"Document {i}: This is test content about Person{i} who works at Company{i}.",
            "context": f"test document {i}",
            "document_id": f"doc_{i}",
        }
        for i in range(num_documents)
    ]

    # Submit all async retain operations in parallel
    async def submit_async_retain(doc):
        return await api_client.post(
            f"/v1/default/banks/{test_bank_id}/memories",
            json={
                "async": True,
                "items": [doc],
            },
        )

    # Run all submissions concurrently
    responses = await asyncio.gather(*[submit_async_retain(doc) for doc in documents])

    # Verify all submissions succeeded
    for i, response in enumerate(responses):
        assert response.status_code == 200, f"Document {i} submission failed"
        result = response.json()
        assert result["success"] is True
        assert result["async"] is True

    # Check operations endpoint - should show pending operations
    response = await api_client.get(f"/v1/default/banks/{test_bank_id}/operations")
    assert response.status_code == 200

    # Wait for all async operations to complete (poll with timeout).
    # The deadline is measured on the event loop's monotonic clock so that
    # time spent inside the HTTP requests counts toward the timeout, not
    # just the sleep intervals.
    max_wait_seconds = 60
    poll_interval = 1.0
    loop = asyncio.get_running_loop()
    deadline = loop.time() + max_wait_seconds
    all_docs_processed = False

    while True:
        # Check document count
        response = await api_client.get(f"/v1/default/banks/{test_bank_id}/documents")
        assert response.status_code == 200
        docs = response.json()["items"]

        if len(docs) >= num_documents:
            all_docs_processed = True
            break

        # Always poll at least once; give up only after the deadline passes.
        if loop.time() >= deadline:
            break

        await asyncio.sleep(poll_interval)

    assert all_docs_processed, f"Expected {num_documents} documents, but only {len(docs)} were processed within {max_wait_seconds} seconds"

    # Verify exact document count
    response = await api_client.get(f"/v1/default/banks/{test_bank_id}/documents")
    assert response.status_code == 200
    final_docs = response.json()["items"]
    assert len(final_docs) == num_documents, f"Expected exactly {num_documents} documents, got {len(final_docs)}"

    # Verify each document exists
    doc_ids = {doc["id"] for doc in final_docs}
    for i in range(num_documents):
        assert f"doc_{i}" in doc_ids, f"Document doc_{i} not found"

    # Verify memories were created for all documents
    response = await api_client.get(
        f"/v1/default/banks/{test_bank_id}/memories/list",
        params={"limit": 100},
    )
    assert response.status_code == 200
    memories = response.json()["items"]
    assert len(memories) >= num_documents, f"Expected at least {num_documents} memories, got {len(memories)}"

    # Verify we can recall content from different documents
    for i in [0, num_documents - 1]:  # Check first and last
        response = await api_client.post(
            f"/v1/default/banks/{test_bank_id}/memories/recall",
            json={
                "query": f"Who works at Company{i}?",
                "thinking_budget": 30,
            },
        )
        assert response.status_code == 200
        results = response.json()["results"]
        assert len(results) > 0, f"Should find memories for document {i}"

0 commit comments

Comments
 (0)