Skip to content

Commit 5991308

Browse files
authored
fix: batch queries on recall (#149)
* fix: batch queries on recall * fix: batch queries on recall
1 parent 7935b0a commit 5991308

File tree

6 files changed

+854
-30
lines changed

6 files changed

+854
-30
lines changed

hindsight-api/hindsight_api/config.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
ENV_GRAPH_RETRIEVER = "HINDSIGHT_API_GRAPH_RETRIEVER"
6565
ENV_MPFP_TOP_K_NEIGHBORS = "HINDSIGHT_API_MPFP_TOP_K_NEIGHBORS"
6666
ENV_RECALL_MAX_CONCURRENT = "HINDSIGHT_API_RECALL_MAX_CONCURRENT"
67+
ENV_RECALL_CONNECTION_BUDGET = "HINDSIGHT_API_RECALL_CONNECTION_BUDGET"
6768
ENV_MCP_LOCAL_BANK_ID = "HINDSIGHT_API_MCP_LOCAL_BANK_ID"
6869
ENV_MCP_INSTRUCTIONS = "HINDSIGHT_API_MCP_INSTRUCTIONS"
6970

@@ -128,6 +129,7 @@
128129
DEFAULT_GRAPH_RETRIEVER = "link_expansion" # Options: "link_expansion", "mpfp", "bfs"
129130
DEFAULT_MPFP_TOP_K_NEIGHBORS = 20 # Fan-out limit per node in MPFP graph traversal
130131
DEFAULT_RECALL_MAX_CONCURRENT = 32 # Max concurrent recall operations per worker
132+
DEFAULT_RECALL_CONNECTION_BUDGET = 4 # Max concurrent DB connections per recall operation
131133
DEFAULT_MCP_LOCAL_BANK_ID = "mcp"
132134

133135
# Observation thresholds
@@ -241,6 +243,7 @@ class HindsightConfig:
241243
graph_retriever: str
242244
mpfp_top_k_neighbors: int
243245
recall_max_concurrent: int
246+
recall_connection_budget: int
244247

245248
# Observation thresholds
246249
observation_min_facts: int
@@ -315,6 +318,9 @@ def from_env(cls) -> "HindsightConfig":
315318
graph_retriever=os.getenv(ENV_GRAPH_RETRIEVER, DEFAULT_GRAPH_RETRIEVER),
316319
mpfp_top_k_neighbors=int(os.getenv(ENV_MPFP_TOP_K_NEIGHBORS, str(DEFAULT_MPFP_TOP_K_NEIGHBORS))),
317320
recall_max_concurrent=int(os.getenv(ENV_RECALL_MAX_CONCURRENT, str(DEFAULT_RECALL_MAX_CONCURRENT))),
321+
recall_connection_budget=int(
322+
os.getenv(ENV_RECALL_CONNECTION_BUDGET, str(DEFAULT_RECALL_CONNECTION_BUDGET))
323+
),
318324
# Optimization flags
319325
skip_llm_verification=os.getenv(ENV_SKIP_LLM_VERIFICATION, "false").lower() == "true",
320326
lazy_reranker=os.getenv(ENV_LAZY_RERANKER, "false").lower() == "true",
Lines changed: 284 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,284 @@
1+
"""
2+
Database connection budget management.
3+
4+
Limits concurrent database connections per operation to prevent
5+
a single operation (e.g., recall with parallel queries) from
6+
exhausting the connection pool.
7+
"""
8+
9+
import asyncio
10+
import logging
11+
import uuid
12+
from contextlib import asynccontextmanager
13+
from dataclasses import dataclass, field
14+
from typing import TYPE_CHECKING, AsyncIterator
15+
16+
if TYPE_CHECKING:
17+
import asyncpg
18+
19+
logger = logging.getLogger(__name__)
20+
21+
22+
@dataclass
23+
class OperationBudget:
24+
"""
25+
Tracks connection budget for a single operation.
26+
27+
Each operation gets a semaphore limiting its concurrent connections.
28+
"""
29+
30+
operation_id: str
31+
max_connections: int
32+
semaphore: asyncio.Semaphore = field(init=False)
33+
active_count: int = field(default=0, init=False)
34+
35+
def __post_init__(self):
36+
self.semaphore = asyncio.Semaphore(self.max_connections)
37+
38+
39+
class ConnectionBudgetManager:
40+
"""
41+
Manages per-operation connection budgets.
42+
43+
Usage:
44+
manager = ConnectionBudgetManager(default_budget=4)
45+
46+
# Start an operation
47+
async with manager.operation(max_connections=2) as op:
48+
# Acquire connections within the budget
49+
async with op.acquire(pool) as conn:
50+
await conn.fetch(...)
51+
52+
# Multiple connections respect the budget
53+
async with op.acquire(pool) as conn1, op.acquire(pool) as conn2:
54+
# At most 2 concurrent connections for this operation
55+
...
56+
"""
57+
58+
def __init__(self, default_budget: int = 4):
59+
"""
60+
Initialize the budget manager.
61+
62+
Args:
63+
default_budget: Default max connections per operation
64+
"""
65+
self.default_budget = default_budget
66+
self._operations: dict[str, OperationBudget] = {}
67+
self._lock = asyncio.Lock()
68+
69+
@asynccontextmanager
70+
async def operation(
71+
self,
72+
max_connections: int | None = None,
73+
operation_id: str | None = None,
74+
) -> AsyncIterator["BudgetedOperation"]:
75+
"""
76+
Create a budgeted operation context.
77+
78+
Args:
79+
max_connections: Max concurrent connections for this operation.
80+
Defaults to manager's default_budget.
81+
operation_id: Optional custom operation ID. Auto-generated if not provided.
82+
83+
Yields:
84+
BudgetedOperation context for acquiring connections
85+
"""
86+
op_id = operation_id or f"op-{uuid.uuid4().hex[:12]}"
87+
budget = max_connections or self.default_budget
88+
89+
async with self._lock:
90+
if op_id in self._operations:
91+
raise ValueError(f"Operation {op_id} already exists")
92+
self._operations[op_id] = OperationBudget(op_id, budget)
93+
94+
try:
95+
yield BudgetedOperation(self, op_id)
96+
finally:
97+
async with self._lock:
98+
self._operations.pop(op_id, None)
99+
100+
def _get_budget(self, operation_id: str) -> OperationBudget:
101+
"""Get budget for an operation (internal use)."""
102+
budget = self._operations.get(operation_id)
103+
if not budget:
104+
raise ValueError(f"Operation {operation_id} not found")
105+
return budget
106+
107+
108+
class BudgetedOperation:
109+
"""
110+
A single operation with connection budget.
111+
112+
Provides methods to acquire connections within the budget.
113+
"""
114+
115+
def __init__(self, manager: ConnectionBudgetManager, operation_id: str):
116+
self._manager = manager
117+
self.operation_id = operation_id
118+
119+
@property
120+
def budget(self) -> OperationBudget:
121+
"""Get the budget for this operation."""
122+
return self._manager._get_budget(self.operation_id)
123+
124+
@asynccontextmanager
125+
async def acquire(self, pool: "asyncpg.Pool") -> AsyncIterator["asyncpg.Connection"]:
126+
"""
127+
Acquire a connection within the operation's budget.
128+
129+
Blocks if the operation has reached its connection limit.
130+
131+
Args:
132+
pool: asyncpg connection pool
133+
134+
Yields:
135+
Database connection
136+
"""
137+
budget = self.budget
138+
async with budget.semaphore:
139+
budget.active_count += 1
140+
conn = await pool.acquire()
141+
try:
142+
yield conn
143+
finally:
144+
budget.active_count -= 1
145+
await pool.release(conn)
146+
147+
def wrap_pool(self, pool: "asyncpg.Pool") -> "BudgetedPool":
148+
"""
149+
Wrap a pool with this operation's budget.
150+
151+
The returned BudgetedPool can be passed to functions expecting a pool,
152+
and all acquire() calls will be limited by this operation's budget.
153+
154+
Args:
155+
pool: asyncpg connection pool to wrap
156+
157+
Returns:
158+
BudgetedPool that limits connections to this operation's budget
159+
"""
160+
return BudgetedPool(pool, self)
161+
162+
async def acquire_many(
163+
self,
164+
pool: "asyncpg.Pool",
165+
count: int,
166+
) -> AsyncIterator[list["asyncpg.Connection"]]:
167+
"""
168+
Acquire multiple connections within the budget.
169+
170+
Note: This acquires connections sequentially to respect the budget.
171+
For parallel acquisition, use multiple acquire() calls with asyncio.gather().
172+
173+
Args:
174+
pool: asyncpg connection pool
175+
count: Number of connections to acquire
176+
177+
Yields:
178+
List of database connections
179+
"""
180+
connections = []
181+
try:
182+
for _ in range(count):
183+
conn = await pool.acquire()
184+
connections.append(conn)
185+
yield connections
186+
finally:
187+
for conn in connections:
188+
await pool.release(conn)
189+
190+
191+
# Global default manager instance
192+
_default_manager: ConnectionBudgetManager | None = None
193+
194+
195+
def get_budget_manager(default_budget: int = 4) -> ConnectionBudgetManager:
196+
"""
197+
Get or create the global budget manager.
198+
199+
Args:
200+
default_budget: Default max connections per operation
201+
202+
Returns:
203+
Global ConnectionBudgetManager instance
204+
"""
205+
global _default_manager
206+
if _default_manager is None:
207+
_default_manager = ConnectionBudgetManager(default_budget=default_budget)
208+
return _default_manager
209+
210+
211+
@asynccontextmanager
212+
async def budgeted_operation(
213+
max_connections: int | None = None,
214+
operation_id: str | None = None,
215+
default_budget: int = 4,
216+
) -> AsyncIterator[BudgetedOperation]:
217+
"""
218+
Convenience function to create a budgeted operation.
219+
220+
Args:
221+
max_connections: Max concurrent connections for this operation
222+
operation_id: Optional custom operation ID
223+
default_budget: Default budget if manager not yet created
224+
225+
Yields:
226+
BudgetedOperation context
227+
228+
Example:
229+
async with budgeted_operation(max_connections=2) as op:
230+
async with op.acquire(pool) as conn:
231+
await conn.fetch(...)
232+
"""
233+
manager = get_budget_manager(default_budget)
234+
async with manager.operation(max_connections, operation_id) as op:
235+
yield op
236+
237+
238+
class BudgetedPool:
239+
"""
240+
A pool wrapper that limits concurrent connection acquisitions.
241+
242+
This can be passed to functions expecting a pool, and acquire()
243+
calls will be limited by the budget semaphore.
244+
245+
Usage:
246+
async with budgeted_operation(max_connections=4) as op:
247+
budgeted_pool = op.wrap_pool(pool)
248+
# Pass budgeted_pool to functions that expect a pool
249+
await some_function(budgeted_pool, ...)
250+
"""
251+
252+
def __init__(self, pool: "asyncpg.Pool", operation: BudgetedOperation):
253+
self._pool = pool
254+
self._operation = operation
255+
256+
async def acquire(self) -> "asyncpg.Connection":
257+
"""
258+
Acquire a connection within the budget.
259+
260+
Note: Caller must release the connection when done.
261+
Prefer using as context manager via acquire_with_retry or op.acquire().
262+
"""
263+
budget = self._operation.budget
264+
await budget.semaphore.acquire()
265+
budget.active_count += 1
266+
try:
267+
return await self._pool.acquire()
268+
except Exception:
269+
budget.active_count -= 1
270+
budget.semaphore.release()
271+
raise
272+
273+
async def release(self, conn: "asyncpg.Connection") -> None:
274+
"""Release a connection back to the pool."""
275+
budget = self._operation.budget
276+
try:
277+
await self._pool.release(conn)
278+
finally:
279+
budget.active_count -= 1
280+
budget.semaphore.release()
281+
282+
def __getattr__(self, name):
283+
"""Proxy other attributes to the underlying pool."""
284+
return getattr(self._pool, name)

0 commit comments

Comments
 (0)