Skip to content

Commit 7accac9

Browse files
authored
fix: migrate mental_models.embedding dimension alongside memory_units (#526)
ensure_embedding_dimension() now also checks and migrates mental_models.embedding, fixing silent failures when changing embedding model dimensions. Extracted shared per-table logic into _migrate_table_embedding_dimension() to avoid duplication. Adds test coverage for the mental_models dimension migration path. Fixes #523
1 parent fa3501d commit 7accac9

File tree

2 files changed

+203
-119
lines changed

2 files changed

+203
-119
lines changed

hindsight-api/hindsight_api/migrations.py

Lines changed: 125 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from alembic import command
2626
from alembic.config import Config
2727
from alembic.script.revision import ResolutionError
28-
from sqlalchemy import create_engine, text
28+
from sqlalchemy import Connection, create_engine, text
2929

3030
from .utils import mask_network_location
3131

@@ -471,17 +471,135 @@ def check_migration_status(
471471
return None, None
472472

473473

474+
def _migrate_table_embedding_dimension(
    conn: Connection,
    schema_name: str,
    table_name: str,
    required_dimension: int,
    vector_ext: str,
) -> None:
    """
    Migrate the embedding column of a single table to the required dimension.

    - If the table has no embedding column: no action needed
    - If dimensions match: no action needed
    - If dimensions differ and table is empty: ALTER COLUMN to new dimension
    - If dimensions differ and table has data: raise error with migration guidance

    Args:
        conn: Open connection used for every statement. This function calls
            ``conn.commit()`` itself (after the ALTER and after the index is
            recreated), so a caller migrating several tables does not get an
            all-or-nothing migration across them.
        schema_name: Schema containing the table. Interpolated directly into
            SQL, so it must be a trusted identifier.
        table_name: Table whose ``embedding`` column is checked. Interpolated
            directly into SQL, so it must be a trusted identifier.
        required_dimension: Dimension produced by the configured embedding model.
        vector_ext: Detected vector extension. ``"pgvectorscale"`` selects a
            DiskANN index, ``"vchord"`` a vchordrq index, anything else an
            HNSW index.

    Raises:
        RuntimeError: If the table still holds rows with embeddings, or if an
            HNSW index would be needed for more than 2000 dimensions.
    """
    # The vector type's dimension is read from atttypmod.
    current_dim = conn.execute(
        text("""
            SELECT atttypmod
            FROM pg_attribute a
            JOIN pg_class c ON a.attrelid = c.oid
            JOIN pg_namespace n ON c.relnamespace = n.oid
            WHERE n.nspname = :schema
              AND c.relname = :table
              AND a.attname = 'embedding'
        """),
        {"schema": schema_name, "table": table_name},
    ).scalar()

    if current_dim is None:
        logger.debug(f"No embedding column found on {table_name}, skipping")
        return

    if current_dim == required_dimension:
        logger.debug(f"Embedding dimension OK for {table_name}: {current_dim}")
        return

    logger.info(
        f"Embedding dimension mismatch on {table_name}: database has {current_dim}, model requires {required_dimension}"
    )

    row_count = conn.execute(
        text(f"SELECT COUNT(*) FROM {schema_name}.{table_name} WHERE embedding IS NOT NULL")
    ).scalar()

    if row_count > 0:
        raise RuntimeError(
            f"Cannot change embedding dimension from {current_dim} to {required_dimension}: "
            f"{table_name} table contains {row_count} rows with embeddings. "
            f"To change dimensions, you must either:\n"
            f"  1. Re-embed all data: DELETE FROM {schema_name}.{table_name}; then restart\n"
            f"  2. Use a model with {current_dim}-dimensional embeddings"
        )

    # Validate the HNSW limit BEFORE touching the schema. Checking it only when
    # the index is recreated would raise after the old index was dropped and
    # the ALTER committed, leaving the column resized but unindexed (and the
    # next start-up would see matching dimensions and never rebuild the index).
    uses_hnsw = vector_ext not in ("pgvectorscale", "vchord")
    if uses_hnsw and required_dimension > 2000:
        raise RuntimeError(
            f"Embedding dimension {required_dimension} exceeds pgvector HNSW index limit of 2000. "
            f"Use an embedding model with <= 2000 dimensions, or switch to a vector extension "
            f"that supports higher dimensions (e.g., pgvectorscale/DiskANN)."
        )

    logger.info(f"Altering {table_name}.embedding column dimension from {current_dim} to {required_dimension}")

    # Drop any existing vector index on the embedding column (HNSW, vchordrq
    # or DiskANN) before the column type is changed below.
    conn.execute(
        text(f"""
            DO $$
            DECLARE idx_name TEXT;
            BEGIN
                FOR idx_name IN
                    SELECT indexname FROM pg_indexes
                    WHERE schemaname = '{schema_name}'
                      AND tablename = '{table_name}'
                      AND (indexdef LIKE '%hnsw%' OR indexdef LIKE '%vchordrq%' OR indexdef LIKE '%diskann%')
                      AND indexdef LIKE '%embedding%'
                LOOP
                    EXECUTE 'DROP INDEX IF EXISTS {schema_name}.' || idx_name;
                END LOOP;
            END $$;
        """)
    )

    conn.execute(
        text(f"ALTER TABLE {schema_name}.{table_name} ALTER COLUMN embedding TYPE vector({required_dimension})")
    )
    conn.commit()

    # Recreate index with appropriate type based on detected extension
    if vector_ext == "pgvectorscale":
        conn.execute(
            text(f"""
                CREATE INDEX IF NOT EXISTS idx_{table_name}_embedding_diskann
                ON {schema_name}.{table_name}
                USING diskann (embedding vector_cosine_ops)
                WITH (num_neighbors = 50)
            """)
        )
        logger.info(f"Created DiskANN index on {table_name} for {required_dimension}-dimensional embeddings")
    elif vector_ext == "vchord":
        conn.execute(
            text(f"""
                CREATE INDEX IF NOT EXISTS idx_{table_name}_embedding_vchordrq
                ON {schema_name}.{table_name}
                USING vchordrq (embedding vector_l2_ops)
            """)
        )
        logger.info(f"Created vchordrq index on {table_name} for {required_dimension}-dimensional embeddings")
    else:  # pgvector (dimension limit already validated above)
        conn.execute(
            text(f"""
                CREATE INDEX IF NOT EXISTS idx_{table_name}_embedding_hnsw
                ON {schema_name}.{table_name}
                USING hnsw (embedding vector_cosine_ops)
                WITH (m = 16, ef_construction = 64)
            """)
        )
        logger.info(f"Created HNSW index on {table_name} for {required_dimension}-dimensional embeddings")
    conn.commit()

    logger.info(f"Successfully changed {table_name}.embedding dimension to {required_dimension}")
591+
592+
474593
def ensure_embedding_dimension(
475594
database_url: str,
476595
required_dimension: int,
477596
schema: str | None = None,
478597
vector_extension: str = "pgvector",
479598
) -> None:
480599
"""
481-
Ensure the embedding column dimension matches the model's dimension.
600+
Ensure the embedding column dimension matches the model's dimension for all tables.
482601
483-
This function checks the current vector column dimension in the database
484-
and adjusts it if necessary:
602+
Checks and adjusts memory_units.embedding and mental_models.embedding:
485603
- If dimensions match: no action needed
486604
- If dimensions differ and table is empty: ALTER COLUMN to new dimension
487605
- If dimensions differ and table has data: raise error with migration guidance
@@ -499,7 +617,7 @@ def ensure_embedding_dimension(
499617

500618
engine = create_engine(database_url)
501619
with engine.connect() as conn:
502-
# Check if memory_units table exists
620+
# Check if memory_units table exists (proxy for schema being initialized)
503621
table_exists = conn.execute(
504622
text("""
505623
SELECT EXISTS (
@@ -518,117 +636,8 @@ def ensure_embedding_dimension(
518636
vector_ext = _detect_vector_extension(conn, vector_extension)
519637
logger.info(f"Using vector extension: {vector_ext}")
520638

521-
# Get current column dimension from pg_attribute
522-
# pgvector stores dimension in atttypmod
523-
current_dim = conn.execute(
524-
text("""
525-
SELECT atttypmod
526-
FROM pg_attribute a
527-
JOIN pg_class c ON a.attrelid = c.oid
528-
JOIN pg_namespace n ON c.relnamespace = n.oid
529-
WHERE n.nspname = :schema
530-
AND c.relname = 'memory_units'
531-
AND a.attname = 'embedding'
532-
"""),
533-
{"schema": schema_name},
534-
).scalar()
535-
536-
if current_dim is None:
537-
logger.warning("Could not determine current embedding dimension, skipping check")
538-
return
539-
540-
# pgvector stores dimension directly in atttypmod (no offset like other types)
541-
current_dimension = current_dim
542-
543-
if current_dimension == required_dimension:
544-
logger.debug(f"Embedding dimension OK: {current_dimension}")
545-
return
546-
547-
logger.info(
548-
f"Embedding dimension mismatch: database has {current_dimension}, model requires {required_dimension}"
549-
)
550-
551-
# Check if table has data
552-
row_count = conn.execute(
553-
text(f"SELECT COUNT(*) FROM {schema_name}.memory_units WHERE embedding IS NOT NULL")
554-
).scalar()
555-
556-
if row_count > 0:
557-
raise RuntimeError(
558-
f"Cannot change embedding dimension from {current_dimension} to {required_dimension}: "
559-
f"memory_units table contains {row_count} rows with embeddings. "
560-
f"To change dimensions, you must either:\n"
561-
f" 1. Re-embed all data: DELETE FROM {schema_name}.memory_units; then restart\n"
562-
f" 2. Use a model with {current_dimension}-dimensional embeddings"
563-
)
564-
565-
# Table is empty, safe to alter column
566-
logger.info(f"Altering embedding column dimension from {current_dimension} to {required_dimension}")
567-
568-
# Drop existing vector index (works for both HNSW and vchordrq)
569-
conn.execute(
570-
text(f"""
571-
DO $$
572-
DECLARE idx_name TEXT;
573-
BEGIN
574-
FOR idx_name IN
575-
SELECT indexname FROM pg_indexes
576-
WHERE schemaname = '{schema_name}'
577-
AND tablename = 'memory_units'
578-
AND (indexdef LIKE '%hnsw%' OR indexdef LIKE '%vchordrq%')
579-
AND indexdef LIKE '%embedding%'
580-
LOOP
581-
EXECUTE 'DROP INDEX IF EXISTS {schema_name}.' || idx_name;
582-
END LOOP;
583-
END $$;
584-
""")
585-
)
586-
587-
# Alter the column type
588-
conn.execute(
589-
text(f"ALTER TABLE {schema_name}.memory_units ALTER COLUMN embedding TYPE vector({required_dimension})")
590-
)
591-
conn.commit()
592-
593-
# Recreate index with appropriate type based on detected extension
594-
if vector_ext == "pgvectorscale":
595-
conn.execute(
596-
text(f"""
597-
CREATE INDEX IF NOT EXISTS idx_memory_units_embedding_diskann
598-
ON {schema_name}.memory_units
599-
USING diskann (embedding vector_cosine_ops)
600-
WITH (num_neighbors = 50)
601-
""")
602-
)
603-
logger.info(f"Created DiskANN index for {required_dimension}-dimensional embeddings")
604-
elif vector_ext == "vchord":
605-
conn.execute(
606-
text(f"""
607-
CREATE INDEX IF NOT EXISTS idx_memory_units_embedding_vchordrq
608-
ON {schema_name}.memory_units
609-
USING vchordrq (embedding vector_l2_ops)
610-
""")
611-
)
612-
logger.info(f"Created vchordrq index for {required_dimension}-dimensional embeddings")
613-
else: # pgvector
614-
if required_dimension > 2000:
615-
raise RuntimeError(
616-
f"Embedding dimension {required_dimension} exceeds pgvector HNSW index limit of 2000. "
617-
f"Use an embedding model with <= 2000 dimensions, or switch to a vector extension "
618-
f"that supports higher dimensions (e.g., pgvectorscale/DiskANN)."
619-
)
620-
conn.execute(
621-
text(f"""
622-
CREATE INDEX IF NOT EXISTS idx_memory_units_embedding_hnsw
623-
ON {schema_name}.memory_units
624-
USING hnsw (embedding vector_cosine_ops)
625-
WITH (m = 16, ef_construction = 64)
626-
""")
627-
)
628-
logger.info(f"Created HNSW index for {required_dimension}-dimensional embeddings")
629-
conn.commit()
630-
631-
logger.info(f"Successfully changed embedding dimension to {required_dimension}")
639+
_migrate_table_embedding_dimension(conn, schema_name, "memory_units", required_dimension, vector_ext)
640+
_migrate_table_embedding_dimension(conn, schema_name, "mental_models", required_dimension, vector_ext)
632641

633642

634643
def ensure_vector_extension(

hindsight-api/tests/test_custom_embedding_dimension.py

Lines changed: 78 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def drop_schema(db_url: str, schema_name: str):
7575
conn.commit()
7676

7777

78-
def get_column_dimension(db_url: str, schema: str = "public") -> int | None:
78+
def get_column_dimension(db_url: str, schema: str = "public", table: str = "memory_units") -> int | None:
7979
"""Get the current embedding column dimension from the database."""
8080
engine = create_engine(db_url)
8181
with engine.connect() as conn:
@@ -86,10 +86,10 @@ def get_column_dimension(db_url: str, schema: str = "public") -> int | None:
8686
JOIN pg_class c ON a.attrelid = c.oid
8787
JOIN pg_namespace n ON c.relnamespace = n.oid
8888
WHERE n.nspname = :schema
89-
AND c.relname = 'memory_units'
89+
AND c.relname = :table
9090
AND a.attname = 'embedding'
9191
"""),
92-
{"schema": schema},
92+
{"schema": schema, "table": table},
9393
).scalar()
9494
return result
9595

@@ -125,6 +125,38 @@ def clear_embeddings(db_url: str, schema: str):
125125
conn.commit()
126126

127127

128+
def insert_test_mental_model_embedding(db_url: str, schema: str, dimension: int):
    """Insert one mental_models row whose embedding is ``dimension`` copies of 0.1.

    The row belongs to bank ``test-bank-mm``, which is created first if it does
    not exist yet. Both inserts are committed together.
    """
    # Vector literal in the textual form cast below, e.g. "[0.1,0.1,0.1]".
    vector_literal = "[" + ",".join(["0.1"] * dimension) + "]"

    # Ensure test bank exists
    upsert_bank = text(f"""
        INSERT INTO {schema}.banks (bank_id, name)
        VALUES ('test-bank-mm', 'Test Bank')
        ON CONFLICT (bank_id) DO NOTHING
    """)
    insert_model = text(f"""
        INSERT INTO {schema}.mental_models (bank_id, name, source_query, content, embedding)
        VALUES ('test-bank-mm', 'test model', 'test query', 'test content', '{vector_literal}'::vector)
    """)

    with create_engine(db_url).connect() as conn:
        for statement in (upsert_bank, insert_model):
            conn.execute(statement)
        conn.commit()
150+
151+
152+
def clear_mental_model_embeddings(db_url: str, schema: str):
    """Delete every row from the schema's mental_models table and commit."""
    delete_all = text(f"DELETE FROM {schema}.mental_models")
    with create_engine(db_url).connect() as conn:
        conn.execute(delete_all)
        conn.commit()
158+
159+
128160
# =============================================================================
129161
# Embedding Dimension Tests (Local Embeddings)
130162
# =============================================================================
@@ -199,6 +231,49 @@ def test_dimension_change_blocked_with_data(self, dimension_test_schema):
199231
# Cleanup
200232
clear_embeddings(db_url, schema)
201233

234+
def test_mental_models_dimension_matches_no_change(self, dimension_test_schema):
    """A mental_models column already at the requested dimension is left untouched."""
    db_url, schema = dimension_test_schema

    def mental_models_dim():
        return get_column_dimension(db_url, schema, table="mental_models")

    # Precondition: the fixture schema starts out with 384-dimensional embeddings.
    initial_dim = mental_models_dim()
    assert initial_dim == 384, f"Expected 384, got {initial_dim}"

    # Requesting the dimension that is already in place must not change it.
    ensure_embedding_dimension(db_url, 384, schema=schema)

    assert mental_models_dim() == 384
244+
245+
def test_mental_models_dimension_change_empty_table(self, dimension_test_schema):
    """When mental_models is empty, dimension can be changed."""
    db_url, schema = dimension_test_schema

    clear_mental_model_embeddings(db_url, schema)

    try:
        ensure_embedding_dimension(db_url, 768, schema=schema)

        assert get_column_dimension(db_url, schema, table="mental_models") == 768
    finally:
        # Change back for other tests. Done in ``finally`` so a failed
        # assertion above cannot leave the shared schema at 768 dimensions.
        ensure_embedding_dimension(db_url, 384, schema=schema)

    assert get_column_dimension(db_url, schema, table="mental_models") == 384
258+
259+
def test_mental_models_dimension_change_blocked_with_data(self, dimension_test_schema):
    """When mental_models has data, dimension change should be blocked."""
    db_url, schema = dimension_test_schema

    clear_mental_model_embeddings(db_url, schema)
    insert_test_mental_model_embedding(db_url, schema, 384)

    try:
        with pytest.raises(RuntimeError) as exc_info:
            ensure_embedding_dimension(db_url, 768, schema=schema)

        assert "Cannot change embedding dimension" in str(exc_info.value)
        assert "mental_models" in str(exc_info.value)

        assert get_column_dimension(db_url, schema, table="mental_models") == 384
    finally:
        # Cleanup runs even when an assertion above fails, so the inserted row
        # cannot block dimension changes in later tests.
        clear_mental_model_embeddings(db_url, schema)
        # ensure_embedding_dimension migrates memory_units before mental_models
        # and commits per table, so an empty memory_units may already have been
        # switched to 768 before the expected error was raised. Restore 384.
        ensure_embedding_dimension(db_url, 384, schema=schema)
276+
202277
def test_local_embeddings_dimension_detection(self, embeddings):
203278
"""Test that LocalSTEmbeddings correctly detects dimension."""
204279
# Initialize embeddings if not already done

0 commit comments

Comments
 (0)