Skip to content

Commit 3825506

Browse files
authored
fix: custom pg schema is not reliable (#278)
* fix: custom pg schema is not reliable * fix * fix * fix: WorkerPoller now always has tenant extension Ensures WorkerPoller follows same pattern as MemoryEngine - always creates a DefaultTenantExtension if none is provided, preventing NoneType errors when calling list_tenants(). Fixes test failures in test_worker.py * fix: DefaultTenantExtension honors explicit schema parameter Allows WorkerPoller's schema parameter to be passed through to DefaultTenantExtension via config dict, maintaining backward compatibility for tests that use schema parameter without providing a tenant extension. Fixes test_poller_with_custom_schema test failure.
1 parent 6c7f057 commit 3825506

File tree

7 files changed

+205
-54
lines changed

7 files changed

+205
-54
lines changed

hindsight-api/hindsight_api/alembic/versions/5a366d414dce_initial_schema.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import sqlalchemy as sa
1212
from alembic import op
1313
from pgvector.sqlalchemy import Vector
14+
from sqlalchemy import text
1415
from sqlalchemy.dialects import postgresql
1516

1617
# revision identifiers, used by Alembic.
@@ -23,8 +24,21 @@
2324
def upgrade() -> None:
2425
"""Upgrade schema - create all tables from scratch."""
2526

26-
# Enable required extensions
27-
op.execute("CREATE EXTENSION IF NOT EXISTS vector")
27+
# Note: pgvector extension is installed globally BEFORE migrations run
28+
# See migrations.py:run_migrations() - this ensures the extension is available
29+
# to all schemas, not just the one being migrated
30+
31+
# We keep this here as a fallback for backwards compatibility
32+
# This may fail if user lacks permissions, which is fine if extension already exists
33+
try:
34+
op.execute("CREATE EXTENSION IF NOT EXISTS vector")
35+
except Exception:
36+
# Extension might already exist or user lacks permissions - verify it exists
37+
conn = op.get_bind()
38+
result = conn.execute(text("SELECT 1 FROM pg_extension WHERE extname = 'vector'")).fetchone()
39+
if not result:
40+
# Extension truly doesn't exist - re-raise the error
41+
raise
2842

2943
# Create banks table
3044
op.create_table(

hindsight-api/hindsight_api/api/http.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1410,7 +1410,7 @@ async def lifespan(app: FastAPI):
14101410
poll_interval_ms=config.worker_poll_interval_ms,
14111411
max_retries=config.worker_max_retries,
14121412
schema=schema,
1413-
tenant_extension=getattr(memory, "_tenant_extension", None),
1413+
tenant_extension=memory._tenant_extension,
14141414
max_slots=config.worker_max_slots,
14151415
consolidation_max_slots=config.worker_consolidation_max_slots,
14161416
)

hindsight-api/hindsight_api/engine/memory_engine.py

Lines changed: 35 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -459,7 +459,11 @@ def __init__(
459459
# Store operation validator extension (optional)
460460
self._operation_validator = operation_validator
461461

462-
# Store tenant extension (optional)
462+
# Store tenant extension (always set, use default if none provided)
463+
if tenant_extension is None:
464+
from ..extensions.builtin.tenant import DefaultTenantExtension
465+
466+
tenant_extension = DefaultTenantExtension(config={})
463467
self._tenant_extension = tenant_extension
464468

465469
async def _validate_operation(self, validation_coro) -> None:
@@ -497,22 +501,18 @@ async def _authenticate_tenant(self, request_context: "RequestContext | None") -
497501
Raises:
498502
AuthenticationError: If authentication fails or request_context is missing when required.
499503
"""
500-
if self._tenant_extension is None:
501-
_current_schema.set("public")
502-
return "public"
503-
504504
from hindsight_api.extensions import AuthenticationError
505505

506506
if request_context is None:
507-
raise AuthenticationError("RequestContext is required when tenant extension is configured")
507+
raise AuthenticationError("RequestContext is required")
508508

509509
# For internal/background operations (e.g., worker tasks), skip extension authentication.
510510
# The task was already authenticated at submission time, and execute_task sets _current_schema
511-
# from the task's _schema field. For public schema tasks, _current_schema keeps its default "public".
511+
# from the task's _schema field.
512512
if request_context.internal:
513513
return _current_schema.get()
514514

515-
# Let AuthenticationError propagate - HTTP layer will convert to 401
515+
# Authenticate through tenant extension (always set, may be default no-auth extension)
516516
tenant_context = await self._tenant_extension.authenticate(request_context)
517517

518518
_current_schema.set(tenant_context.schema_name)
@@ -939,30 +939,34 @@ async def verify_llm():
939939

940940
if not self.db_url:
941941
raise ValueError("Database URL is required for migrations")
942-
logger.info("Running database migrations...")
943-
# Use configured database schema for migrations (defaults to "public")
944-
run_migrations(self.db_url, schema=get_config().database_schema)
945-
946-
# Migrate all existing tenant schemas (if multi-tenant)
947-
if self._tenant_extension is not None:
948-
try:
949-
tenants = await self._tenant_extension.list_tenants()
950-
if tenants:
951-
logger.info(f"Running migrations on {len(tenants)} tenant schemas...")
952-
for tenant in tenants:
953-
schema = tenant.schema
954-
if schema and schema != "public":
955-
try:
956-
run_migrations(self.db_url, schema=schema)
957-
except Exception as e:
958-
logger.warning(f"Failed to migrate tenant schema {schema}: {e}")
959-
logger.info("Tenant schema migrations completed")
960-
except Exception as e:
961-
logger.warning(f"Failed to run tenant schema migrations: {e}")
962942

963-
# Ensure embedding column dimension matches the model's dimension
964-
# This is done after migrations and after embeddings.initialize()
965-
ensure_embedding_dimension(self.db_url, self.embeddings.dimension, schema=get_config().database_schema)
943+
# Migrate all schemas from the tenant extension
944+
# The tenant extension is the single source of truth for which schemas exist
945+
logger.info("Running database migrations...")
946+
try:
947+
tenants = await self._tenant_extension.list_tenants()
948+
if tenants:
949+
logger.info(f"Running migrations on {len(tenants)} schema(s)...")
950+
for tenant in tenants:
951+
schema = tenant.schema
952+
if schema:
953+
try:
954+
run_migrations(self.db_url, schema=schema)
955+
except Exception as e:
956+
logger.warning(f"Failed to migrate schema {schema}: {e}")
957+
logger.info("Schema migrations completed")
958+
959+
# Ensure embedding column dimension matches the model's dimension
960+
# This is done after migrations and after embeddings.initialize()
961+
for tenant in tenants:
962+
schema = tenant.schema
963+
if schema:
964+
try:
965+
ensure_embedding_dimension(self.db_url, self.embeddings.dimension, schema=schema)
966+
except Exception as e:
967+
logger.warning(f"Failed to ensure embedding dimension for schema {schema}: {e}")
968+
except Exception as e:
969+
logger.warning(f"Failed to run schema migrations: {e}")
966970

967971
logger.info(f"Connecting to PostgreSQL at {self.db_url}")
968972

hindsight-api/hindsight_api/extensions/builtin/tenant.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,42 @@
55
from hindsight_api.models import RequestContext
66

77

8+
class DefaultTenantExtension(TenantExtension):
9+
"""
10+
Default single-tenant extension with no authentication.
11+
12+
This is the default extension used when no tenant extension is configured.
13+
It provides single-tenant behavior using the configured schema from
14+
HINDSIGHT_API_DATABASE_SCHEMA (defaults to 'public').
15+
16+
Features:
17+
- No authentication required (passes all requests)
18+
- Uses configured schema from environment
19+
- Perfect for single-tenant deployments without auth
20+
21+
Configuration:
22+
HINDSIGHT_API_DATABASE_SCHEMA=your-schema (optional, defaults to 'public')
23+
24+
This is automatically enabled by default. To use custom authentication,
25+
configure a different tenant extension:
26+
HINDSIGHT_API_TENANT_EXTENSION=hindsight_api.extensions.builtin.tenant:ApiKeyTenantExtension
27+
"""
28+
29+
def __init__(self, config: dict[str, str]):
30+
super().__init__(config)
31+
# Cache the schema at initialization for consistency
32+
# Support explicit schema override via config, otherwise use environment
33+
self._schema = config.get("schema", get_config().database_schema)
34+
35+
async def authenticate(self, context: RequestContext) -> TenantContext:
36+
"""Return configured schema without any authentication."""
37+
return TenantContext(schema_name=self._schema)
38+
39+
async def list_tenants(self) -> list[Tenant]:
40+
"""Return configured schema for single-tenant setup."""
41+
return [Tenant(schema=self._schema)]
42+
43+
844
class ApiKeyTenantExtension(TenantExtension):
945
"""
1046
Built-in tenant extension that validates API key against an environment variable.

hindsight-api/hindsight_api/migrations.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,81 @@ def run_migrations(
165165
logger.debug("Migration advisory lock acquired")
166166

167167
try:
168+
# Ensure pgvector extension is installed globally BEFORE schema migrations
169+
# This is critical: the extension must exist database-wide before any schema
170+
# migrations run, otherwise custom schemas won't have access to vector types
171+
logger.debug("Checking pgvector extension availability...")
172+
173+
# First, check if extension already exists
174+
ext_check = conn.execute(
175+
text(
176+
"SELECT extname, nspname FROM pg_extension e "
177+
"JOIN pg_namespace n ON e.extnamespace = n.oid "
178+
"WHERE extname = 'vector'"
179+
)
180+
).fetchone()
181+
182+
if ext_check:
183+
# Extension exists - check if in correct schema
184+
ext_schema = ext_check[1]
185+
if ext_schema == "public":
186+
logger.info("pgvector extension found in public schema - ready to use")
187+
else:
188+
# Extension in wrong schema - try to fix if we have permissions
189+
logger.warning(
190+
f"pgvector extension found in schema '{ext_schema}' instead of 'public'. "
191+
f"Attempting to relocate..."
192+
)
193+
try:
194+
conn.execute(text("DROP EXTENSION vector CASCADE"))
195+
conn.execute(text("SET search_path TO public"))
196+
conn.execute(text("CREATE EXTENSION vector"))
197+
conn.commit()
198+
logger.info("pgvector extension relocated to public schema")
199+
except Exception as e:
200+
# Failed to relocate - log but don't fail if extension exists somewhere
201+
logger.warning(
202+
f"Could not relocate pgvector extension to public schema: {e}. "
203+
f"Continuing with extension in '{ext_schema}' schema."
204+
)
205+
conn.rollback()
206+
else:
207+
# Extension doesn't exist - try to install
208+
logger.info("pgvector extension not found, attempting to install...")
209+
try:
210+
conn.execute(text("SET search_path TO public"))
211+
conn.execute(text("CREATE EXTENSION vector"))
212+
conn.commit()
213+
logger.info("pgvector extension installed in public schema")
214+
except Exception as e:
215+
# Installation failed - this is only fatal if extension truly doesn't exist
216+
# Check one more time in case another process installed it
217+
conn.rollback()
218+
ext_recheck = conn.execute(
219+
text(
220+
"SELECT nspname FROM pg_extension e "
221+
"JOIN pg_namespace n ON e.extnamespace = n.oid "
222+
"WHERE extname = 'vector'"
223+
)
224+
).fetchone()
225+
226+
if ext_recheck:
227+
logger.warning(
228+
f"Could not install pgvector extension (permission denied?), "
229+
f"but extension exists in '{ext_recheck[0]}' schema. Continuing..."
230+
)
231+
else:
232+
# Extension truly doesn't exist and we can't install it
233+
logger.error(
234+
f"pgvector extension is not installed and cannot be installed: {e}. "
235+
f"Please ensure pgvector is installed by a database administrator. "
236+
f"See: https://github.com/pgvector/pgvector#installation"
237+
)
238+
raise RuntimeError(
239+
"pgvector extension is required but not installed. "
240+
"Please install it with: CREATE EXTENSION vector;"
241+
) from e
242+
168243
# Run migrations while holding the lock
169244
_run_migrations_internal(database_url, script_location, schema=schema)
170245
finally:

hindsight-api/hindsight_api/worker/main.py

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -222,15 +222,30 @@ async def run():
222222
# Create the HTTP app for metrics/health
223223
app = create_worker_app(poller, memory)
224224

225-
# Setup signal handlers for graceful shutdown
225+
# Setup signal handlers for graceful shutdown using asyncio
226226
shutdown_requested = asyncio.Event()
227-
228-
def signal_handler(signum, frame):
229-
print(f"\nReceived signal {signum}, initiating graceful shutdown...")
230-
shutdown_requested.set()
231-
232-
signal.signal(signal.SIGINT, signal_handler)
233-
signal.signal(signal.SIGTERM, signal_handler)
227+
force_exit = False
228+
229+
loop = asyncio.get_event_loop()
230+
231+
def signal_handler():
232+
nonlocal force_exit
233+
if shutdown_requested.is_set():
234+
# Second signal = force exit
235+
print("\nReceived second signal, forcing immediate exit...")
236+
force_exit = True
237+
# Restore default handler so third signal kills process
238+
loop.remove_signal_handler(signal.SIGINT)
239+
loop.remove_signal_handler(signal.SIGTERM)
240+
sys.exit(1)
241+
else:
242+
print("\nReceived shutdown signal, initiating graceful shutdown...")
243+
print("(Press Ctrl+C again to force immediate exit)")
244+
shutdown_requested.set()
245+
246+
# Use asyncio's signal handlers which work properly with the event loop
247+
loop.add_signal_handler(signal.SIGINT, signal_handler)
248+
loop.add_signal_handler(signal.SIGTERM, signal_handler)
234249

235250
# Create uvicorn config and server
236251
uvicorn_config = uvicorn.Config(
@@ -249,7 +264,10 @@ def signal_handler(signum, frame):
249264
print(f"Worker started. Metrics available at http://{args.http_host}:{args.http_port}/metrics")
250265

251266
# Wait for shutdown signal
252-
await shutdown_requested.wait()
267+
try:
268+
await shutdown_requested.wait()
269+
except KeyboardInterrupt:
270+
print("\nReceived interrupt, initiating graceful shutdown...")
253271

254272
# Graceful shutdown
255273
print("Shutting down HTTP server...")

hindsight-api/hindsight_api/worker/poller.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,9 @@ def __init__(
7272
executor: Async function to execute tasks (typically MemoryEngine.execute_task)
7373
poll_interval_ms: Interval between polls when no tasks found (milliseconds)
7474
max_retries: Maximum retry attempts before marking task as failed
75-
schema: Database schema for single-tenant support (ignored if tenant_extension is set)
76-
tenant_extension: Extension for dynamic multi-tenant discovery. If set, list_tenants()
77-
is called on each poll cycle to discover schemas dynamically.
75+
schema: Database schema for single-tenant support (deprecated, use tenant_extension)
76+
tenant_extension: Extension for dynamic multi-tenant discovery. If None, creates a
77+
DefaultTenantExtension with the configured schema.
7878
max_slots: Maximum concurrent tasks per worker
7979
consolidation_max_slots: Maximum concurrent consolidation tasks per worker
8080
"""
@@ -84,6 +84,13 @@ def __init__(
8484
self._poll_interval_ms = poll_interval_ms
8585
self._max_retries = max_retries
8686
self._schema = schema
87+
# Always set tenant extension (use DefaultTenantExtension if none provided)
88+
if tenant_extension is None:
89+
from ..extensions.builtin.tenant import DefaultTenantExtension
90+
91+
# Pass schema parameter to DefaultTenantExtension if explicitly provided
92+
config = {"schema": schema} if schema else {}
93+
tenant_extension = DefaultTenantExtension(config=config)
8794
self._tenant_extension = tenant_extension
8895
self._max_slots = max_slots
8996
self._consolidation_max_slots = consolidation_max_slots
@@ -100,14 +107,11 @@ def __init__(
100107

101108
async def _get_schemas(self) -> list[str | None]:
102109
"""Get list of schemas to poll. Returns [None] for default schema (no prefix)."""
103-
if self._tenant_extension is not None:
104-
from ..config import DEFAULT_DATABASE_SCHEMA
105-
106-
tenants = await self._tenant_extension.list_tenants()
107-
# Convert default schema to None for SQL compatibility (no prefix), keep others as-is
108-
return [t.schema if t.schema != DEFAULT_DATABASE_SCHEMA else None for t in tenants]
109-
# Single schema mode
110-
return [self._schema]
110+
from ..config import DEFAULT_DATABASE_SCHEMA
111+
112+
tenants = await self._tenant_extension.list_tenants()
113+
# Convert default schema to None for SQL compatibility (no prefix), keep others as-is
114+
return [t.schema if t.schema != DEFAULT_DATABASE_SCHEMA else None for t in tenants]
111115

112116
async def _get_available_slots(self) -> tuple[int, int]:
113117
"""

0 commit comments

Comments
 (0)