|
| 1 | +""" |
| 2 | +Database connection budget management. |
| 3 | +
|
| 4 | +Limits concurrent database connections per operation to prevent |
| 5 | +a single operation (e.g., recall with parallel queries) from |
| 6 | +exhausting the connection pool. |
| 7 | +""" |
| 8 | + |
| 9 | +import asyncio |
| 10 | +import logging |
| 11 | +import uuid |
| 12 | +from contextlib import asynccontextmanager |
| 13 | +from dataclasses import dataclass, field |
| 14 | +from typing import TYPE_CHECKING, AsyncIterator |
| 15 | + |
| 16 | +if TYPE_CHECKING: |
| 17 | + import asyncpg |
| 18 | + |
| 19 | +logger = logging.getLogger(__name__) |
| 20 | + |
| 21 | + |
| 22 | +@dataclass |
| 23 | +class OperationBudget: |
| 24 | + """ |
| 25 | + Tracks connection budget for a single operation. |
| 26 | +
|
| 27 | + Each operation gets a semaphore limiting its concurrent connections. |
| 28 | + """ |
| 29 | + |
| 30 | + operation_id: str |
| 31 | + max_connections: int |
| 32 | + semaphore: asyncio.Semaphore = field(init=False) |
| 33 | + active_count: int = field(default=0, init=False) |
| 34 | + |
| 35 | + def __post_init__(self): |
| 36 | + self.semaphore = asyncio.Semaphore(self.max_connections) |
| 37 | + |
| 38 | + |
| 39 | +class ConnectionBudgetManager: |
| 40 | + """ |
| 41 | + Manages per-operation connection budgets. |
| 42 | +
|
| 43 | + Usage: |
| 44 | + manager = ConnectionBudgetManager(default_budget=4) |
| 45 | +
|
| 46 | + # Start an operation |
| 47 | + async with manager.operation(max_connections=2) as op: |
| 48 | + # Acquire connections within the budget |
| 49 | + async with op.acquire(pool) as conn: |
| 50 | + await conn.fetch(...) |
| 51 | +
|
| 52 | + # Multiple connections respect the budget |
| 53 | + async with op.acquire(pool) as conn1, op.acquire(pool) as conn2: |
| 54 | + # At most 2 concurrent connections for this operation |
| 55 | + ... |
| 56 | + """ |
| 57 | + |
| 58 | + def __init__(self, default_budget: int = 4): |
| 59 | + """ |
| 60 | + Initialize the budget manager. |
| 61 | +
|
| 62 | + Args: |
| 63 | + default_budget: Default max connections per operation |
| 64 | + """ |
| 65 | + self.default_budget = default_budget |
| 66 | + self._operations: dict[str, OperationBudget] = {} |
| 67 | + self._lock = asyncio.Lock() |
| 68 | + |
| 69 | + @asynccontextmanager |
| 70 | + async def operation( |
| 71 | + self, |
| 72 | + max_connections: int | None = None, |
| 73 | + operation_id: str | None = None, |
| 74 | + ) -> AsyncIterator["BudgetedOperation"]: |
| 75 | + """ |
| 76 | + Create a budgeted operation context. |
| 77 | +
|
| 78 | + Args: |
| 79 | + max_connections: Max concurrent connections for this operation. |
| 80 | + Defaults to manager's default_budget. |
| 81 | + operation_id: Optional custom operation ID. Auto-generated if not provided. |
| 82 | +
|
| 83 | + Yields: |
| 84 | + BudgetedOperation context for acquiring connections |
| 85 | + """ |
| 86 | + op_id = operation_id or f"op-{uuid.uuid4().hex[:12]}" |
| 87 | + budget = max_connections or self.default_budget |
| 88 | + |
| 89 | + async with self._lock: |
| 90 | + if op_id in self._operations: |
| 91 | + raise ValueError(f"Operation {op_id} already exists") |
| 92 | + self._operations[op_id] = OperationBudget(op_id, budget) |
| 93 | + |
| 94 | + try: |
| 95 | + yield BudgetedOperation(self, op_id) |
| 96 | + finally: |
| 97 | + async with self._lock: |
| 98 | + self._operations.pop(op_id, None) |
| 99 | + |
| 100 | + def _get_budget(self, operation_id: str) -> OperationBudget: |
| 101 | + """Get budget for an operation (internal use).""" |
| 102 | + budget = self._operations.get(operation_id) |
| 103 | + if not budget: |
| 104 | + raise ValueError(f"Operation {operation_id} not found") |
| 105 | + return budget |
| 106 | + |
| 107 | + |
| 108 | +class BudgetedOperation: |
| 109 | + """ |
| 110 | + A single operation with connection budget. |
| 111 | +
|
| 112 | + Provides methods to acquire connections within the budget. |
| 113 | + """ |
| 114 | + |
| 115 | + def __init__(self, manager: ConnectionBudgetManager, operation_id: str): |
| 116 | + self._manager = manager |
| 117 | + self.operation_id = operation_id |
| 118 | + |
| 119 | + @property |
| 120 | + def budget(self) -> OperationBudget: |
| 121 | + """Get the budget for this operation.""" |
| 122 | + return self._manager._get_budget(self.operation_id) |
| 123 | + |
| 124 | + @asynccontextmanager |
| 125 | + async def acquire(self, pool: "asyncpg.Pool") -> AsyncIterator["asyncpg.Connection"]: |
| 126 | + """ |
| 127 | + Acquire a connection within the operation's budget. |
| 128 | +
|
| 129 | + Blocks if the operation has reached its connection limit. |
| 130 | +
|
| 131 | + Args: |
| 132 | + pool: asyncpg connection pool |
| 133 | +
|
| 134 | + Yields: |
| 135 | + Database connection |
| 136 | + """ |
| 137 | + budget = self.budget |
| 138 | + async with budget.semaphore: |
| 139 | + budget.active_count += 1 |
| 140 | + conn = await pool.acquire() |
| 141 | + try: |
| 142 | + yield conn |
| 143 | + finally: |
| 144 | + budget.active_count -= 1 |
| 145 | + await pool.release(conn) |
| 146 | + |
| 147 | + def wrap_pool(self, pool: "asyncpg.Pool") -> "BudgetedPool": |
| 148 | + """ |
| 149 | + Wrap a pool with this operation's budget. |
| 150 | +
|
| 151 | + The returned BudgetedPool can be passed to functions expecting a pool, |
| 152 | + and all acquire() calls will be limited by this operation's budget. |
| 153 | +
|
| 154 | + Args: |
| 155 | + pool: asyncpg connection pool to wrap |
| 156 | +
|
| 157 | + Returns: |
| 158 | + BudgetedPool that limits connections to this operation's budget |
| 159 | + """ |
| 160 | + return BudgetedPool(pool, self) |
| 161 | + |
| 162 | + async def acquire_many( |
| 163 | + self, |
| 164 | + pool: "asyncpg.Pool", |
| 165 | + count: int, |
| 166 | + ) -> AsyncIterator[list["asyncpg.Connection"]]: |
| 167 | + """ |
| 168 | + Acquire multiple connections within the budget. |
| 169 | +
|
| 170 | + Note: This acquires connections sequentially to respect the budget. |
| 171 | + For parallel acquisition, use multiple acquire() calls with asyncio.gather(). |
| 172 | +
|
| 173 | + Args: |
| 174 | + pool: asyncpg connection pool |
| 175 | + count: Number of connections to acquire |
| 176 | +
|
| 177 | + Yields: |
| 178 | + List of database connections |
| 179 | + """ |
| 180 | + connections = [] |
| 181 | + try: |
| 182 | + for _ in range(count): |
| 183 | + conn = await pool.acquire() |
| 184 | + connections.append(conn) |
| 185 | + yield connections |
| 186 | + finally: |
| 187 | + for conn in connections: |
| 188 | + await pool.release(conn) |
| 189 | + |
| 190 | + |
| 191 | +# Global default manager instance |
| 192 | +_default_manager: ConnectionBudgetManager | None = None |
| 193 | + |
| 194 | + |
| 195 | +def get_budget_manager(default_budget: int = 4) -> ConnectionBudgetManager: |
| 196 | + """ |
| 197 | + Get or create the global budget manager. |
| 198 | +
|
| 199 | + Args: |
| 200 | + default_budget: Default max connections per operation |
| 201 | +
|
| 202 | + Returns: |
| 203 | + Global ConnectionBudgetManager instance |
| 204 | + """ |
| 205 | + global _default_manager |
| 206 | + if _default_manager is None: |
| 207 | + _default_manager = ConnectionBudgetManager(default_budget=default_budget) |
| 208 | + return _default_manager |
| 209 | + |
| 210 | + |
| 211 | +@asynccontextmanager |
| 212 | +async def budgeted_operation( |
| 213 | + max_connections: int | None = None, |
| 214 | + operation_id: str | None = None, |
| 215 | + default_budget: int = 4, |
| 216 | +) -> AsyncIterator[BudgetedOperation]: |
| 217 | + """ |
| 218 | + Convenience function to create a budgeted operation. |
| 219 | +
|
| 220 | + Args: |
| 221 | + max_connections: Max concurrent connections for this operation |
| 222 | + operation_id: Optional custom operation ID |
| 223 | + default_budget: Default budget if manager not yet created |
| 224 | +
|
| 225 | + Yields: |
| 226 | + BudgetedOperation context |
| 227 | +
|
| 228 | + Example: |
| 229 | + async with budgeted_operation(max_connections=2) as op: |
| 230 | + async with op.acquire(pool) as conn: |
| 231 | + await conn.fetch(...) |
| 232 | + """ |
| 233 | + manager = get_budget_manager(default_budget) |
| 234 | + async with manager.operation(max_connections, operation_id) as op: |
| 235 | + yield op |
| 236 | + |
| 237 | + |
| 238 | +class BudgetedPool: |
| 239 | + """ |
| 240 | + A pool wrapper that limits concurrent connection acquisitions. |
| 241 | +
|
| 242 | + This can be passed to functions expecting a pool, and acquire() |
| 243 | + calls will be limited by the budget semaphore. |
| 244 | +
|
| 245 | + Usage: |
| 246 | + async with budgeted_operation(max_connections=4) as op: |
| 247 | + budgeted_pool = op.wrap_pool(pool) |
| 248 | + # Pass budgeted_pool to functions that expect a pool |
| 249 | + await some_function(budgeted_pool, ...) |
| 250 | + """ |
| 251 | + |
| 252 | + def __init__(self, pool: "asyncpg.Pool", operation: BudgetedOperation): |
| 253 | + self._pool = pool |
| 254 | + self._operation = operation |
| 255 | + |
| 256 | + async def acquire(self) -> "asyncpg.Connection": |
| 257 | + """ |
| 258 | + Acquire a connection within the budget. |
| 259 | +
|
| 260 | + Note: Caller must release the connection when done. |
| 261 | + Prefer using as context manager via acquire_with_retry or op.acquire(). |
| 262 | + """ |
| 263 | + budget = self._operation.budget |
| 264 | + await budget.semaphore.acquire() |
| 265 | + budget.active_count += 1 |
| 266 | + try: |
| 267 | + return await self._pool.acquire() |
| 268 | + except Exception: |
| 269 | + budget.active_count -= 1 |
| 270 | + budget.semaphore.release() |
| 271 | + raise |
| 272 | + |
| 273 | + async def release(self, conn: "asyncpg.Connection") -> None: |
| 274 | + """Release a connection back to the pool.""" |
| 275 | + budget = self._operation.budget |
| 276 | + try: |
| 277 | + await self._pool.release(conn) |
| 278 | + finally: |
| 279 | + budget.active_count -= 1 |
| 280 | + budget.semaphore.release() |
| 281 | + |
| 282 | + def __getattr__(self, name): |
| 283 | + """Proxy other attributes to the underlying pool.""" |
| 284 | + return getattr(self._pool, name) |
0 commit comments