From b33609417039c484678ea946914eba58aed96a45 Mon Sep 17 00:00:00 2001
From: viragtripathi
Date: Thu, 6 Nov 2025 21:23:14 -0500
Subject: [PATCH] feat: Add CockroachDB vector database support

- UUID PRIMARY KEY for distributed write performance
- Connection pooling with 100+ base connections
- Comprehensive retry logic for transient errors (40001, 40003)
- C-SPANN vector index with tunable parameters
- CLI integration with full parameter support
---
 README.md                                     |  43 +-
 pyproject.toml                                |   1 +
 tests/test_cockroachdb.py                     | 128 +++++
 vectordb_bench/backend/clients/__init__.py    |  16 +
 .../backend/clients/cockroachdb/cli.py        | 110 +++++
 .../clients/cockroachdb/cockroachdb.py        | 439 ++++++++++++++++++
 .../backend/clients/cockroachdb/config.py     | 202 ++++++++
 .../backend/clients/cockroachdb/db_retry.py   | 111 +++++
 vectordb_bench/cli/vectordbbench.py           |   2 +
 9 files changed, 1031 insertions(+), 21 deletions(-)
 create mode 100644 tests/test_cockroachdb.py
 create mode 100644 vectordb_bench/backend/clients/cockroachdb/cli.py
 create mode 100644 vectordb_bench/backend/clients/cockroachdb/cockroachdb.py
 create mode 100644 vectordb_bench/backend/clients/cockroachdb/config.py
 create mode 100644 vectordb_bench/backend/clients/cockroachdb/db_retry.py

diff --git a/README.md b/README.md
index 9b69e95a0..6bba33694 100644
--- a/README.md
+++ b/README.md
@@ -39,26 +39,27 @@ pip install 'vectordb-bench[pinecone]'
 ```
 All the database client supported
-| Optional database client | install command |
-|--------------------------|---------------------------------------------|
-| pymilvus, zilliz_cloud (*default*) | `pip install vectordb-bench` |
-| all (*clients requirements might be conflict with each other*) | `pip install 'vectordb-bench[all]'` |
-| qdrant | `pip install 'vectordb-bench[qdrant]'` |
-| pinecone | `pip install 'vectordb-bench[pinecone]'` |
-| weaviate | `pip install 'vectordb-bench[weaviate]'` |
-| elastic, aliyun_elasticsearch| `pip install 'vectordb-bench[elastic]'` |
-| pgvector, pgvectorscale, pgdiskann, alloydb | `pip install 'vectordb-bench[pgvector]'` |
-| pgvecto.rs | `pip install 'vectordb-bench[pgvecto_rs]'` |
-| redis | `pip install 'vectordb-bench[redis]'` |
-| memorydb | `pip install 'vectordb-bench[memorydb]'` |
-| chromadb | `pip install 'vectordb-bench[chromadb]'` |
-| awsopensearch | `pip install 'vectordb-bench[opensearch]'` |
-| aliyun_opensearch | `pip install 'vectordb-bench[aliyun_opensearch]'` |
-| mongodb | `pip install 'vectordb-bench[mongodb]'` |
-| tidb | `pip install 'vectordb-bench[tidb]'` |
-| vespa | `pip install 'vectordb-bench[vespa]'` |
-| oceanbase | `pip install 'vectordb-bench[oceanbase]'` |
-| hologres | `pip install 'vectordb-bench[hologres]'` |
+| Optional database client                                        | install command                                    |
+|-----------------------------------------------------------------|----------------------------------------------------|
+| pymilvus, zilliz_cloud (*default*)                              | `pip install vectordb-bench`                       |
+| all (*client requirements may conflict with each other*)        | `pip install 'vectordb-bench[all]'`                |
+| qdrant                                                          | `pip install 'vectordb-bench[qdrant]'`             |
+| pinecone                                                        | `pip install 'vectordb-bench[pinecone]'`           |
+| weaviate                                                        | `pip install 'vectordb-bench[weaviate]'`           |
+| elastic, aliyun_elasticsearch                                   | `pip install 'vectordb-bench[elastic]'`            |
+| pgvector, pgvectorscale, pgdiskann, alloydb                     | `pip install 'vectordb-bench[pgvector]'`           |
+| pgvecto.rs                                                      | `pip install 'vectordb-bench[pgvecto_rs]'`         |
+| redis                                                           | `pip install 'vectordb-bench[redis]'`              |
+| memorydb                                                        | `pip install 'vectordb-bench[memorydb]'`           |
+| chromadb                                                        | `pip install 'vectordb-bench[chromadb]'`           |
+| awsopensearch                                                   | `pip install 'vectordb-bench[opensearch]'`         |
+| aliyun_opensearch                                               | `pip install 'vectordb-bench[aliyun_opensearch]'`  |
+| mongodb                                                         | `pip install 'vectordb-bench[mongodb]'`            |
+| tidb                                                            | `pip install 'vectordb-bench[tidb]'`               |
+| cockroachdb                                                     | `pip install 'vectordb-bench[cockroachdb]'`        |
+| vespa                                                           | `pip install 'vectordb-bench[vespa]'`              |
+| oceanbase                                                       | `pip install 'vectordb-bench[oceanbase]'`          |
+| hologres                                                        | `pip install 'vectordb-bench[hologres]'`           |
 
 ### Run
@@ -477,7 +478,7 @@ Now we can only run one task at the same time.
 ### Code Structure
 ![image](https://github.com/zilliztech/VectorDBBench/assets/105927039/8c06512e-5419-4381-b084-9c93aed59639)
 ### Client
-Our client module is designed with flexibility and extensibility in mind, aiming to integrate APIs from different systems seamlessly. As of now, it supports Milvus, Zilliz Cloud, Elastic Search, Pinecone, Qdrant Cloud, Weaviate Cloud, PgVector, Redis, Chroma, etc. Stay tuned for more options, as we are consistently working on extending our reach to other systems.
+Our client module is designed with flexibility and extensibility in mind, aiming to integrate APIs from different systems seamlessly. As of now, it supports Milvus, Zilliz Cloud, Elastic Search, Pinecone, Qdrant Cloud, Weaviate Cloud, PgVector, Redis, Chroma, CockroachDB, etc. Stay tuned for more options, as we are consistently working on extending our reach to other systems.
 ### Benchmark Cases
 We've developed lots of comprehensive benchmark cases to test vector databases' various capabilities, each designed to give you a different piece of the puzzle. These cases are categorized into four main types:
 #### Capacity Case
diff --git a/pyproject.toml b/pyproject.toml
index a922c1fb5..718a13fbf 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -94,6 +94,7 @@ aliyun_opensearch = [ "alibabacloud_ha3engine_vector" ]
 mongodb = [ "pymongo" ]
 mariadb = [ "mariadb" ]
 tidb = [ "PyMySQL" ]
+cockroachdb = [ "psycopg[binary,pool]", "pgvector" ]
 clickhouse = [ "clickhouse-connect" ]
 vespa = [ "pyvespa" ]
 lancedb = [ "lancedb" ]
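The new extra pulls in psycopg 3 (with binary codecs and pool support) and the pgvector adapter — the same modules the client below imports. A minimal sanity check, assuming `pip install 'vectordb-bench[cockroachdb]'` has been run:

    import psycopg
    from pgvector.psycopg import register_vector  # vector type adapter used by the client
    from psycopg_pool import ConnectionPool  # pooling used by the client

    print(psycopg.__version__)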
diff --git a/tests/test_cockroachdb.py b/tests/test_cockroachdb.py
new file mode 100644
index 000000000..b55cded88
--- /dev/null
+++ b/tests/test_cockroachdb.py
@@ -0,0 +1,128 @@
+"""
+Tests for CockroachDB vector database client.
+
+Assumes CockroachDB is running on localhost:26257.
+
+To start CockroachDB locally:
+    cockroach start-single-node --insecure --listen-addr=localhost:26257
+"""
+
+import logging
+
+import numpy as np
+
+from vectordb_bench.backend.filter import IntFilter
+from vectordb_bench.models import DB
+
+log = logging.getLogger(__name__)
+
+
+class TestCockroachDB:
+    """Test suite for CockroachDB vector operations."""
+
+    def test_insert_and_search(self):
+        """Test basic insert and search operations."""
+        assert DB.CockroachDB.value == "CockroachDB"
+
+        dbcls = DB.CockroachDB.init_cls
+
+        # Connection config (matches your local CockroachDB instance)
+        config = {
+            "host": "localhost",
+            "port": 26257,
+            "user_name": "root",
+            "password": "",
+            "db_name": "defaultdb",
+            "table_name": "test_cockroachdb",
+        }
+
+        # Note: sslmode=disable is handled in the client's connect_config options
+
+        dim = 128
+        count = 1000
+
+        # Initialize CockroachDB client
+        cockroachdb = dbcls(
+            dim=dim,
+            db_config=config,
+            db_case_config=None,
+            collection_name="test_cockroachdb",
+            drop_old=True,
+        )
+
+        embeddings = [[np.random.random() for _ in range(dim)] for _ in range(count)]
+
+        # Test insert
+        with cockroachdb.init():
+            res = cockroachdb.insert_embeddings(embeddings=embeddings, metadata=list(range(count)))
+
+            assert res[0] == count, f"Insert count mismatch: {res[0]} != {count}"
+            assert res[1] is None, f"Insert failed with error: {res[1]}"
+
+        # Test search
+        with cockroachdb.init():
+            test_id = np.random.randint(count)
+            q = embeddings[test_id]
+
+            res = cockroachdb.search_embedding(query=q, k=10)
+
+            assert len(res) > 0, "Search returned no results"
+            assert res[0] == int(test_id), f"Top result {res[0]} != query id {test_id}"
+
+        log.info("CockroachDB insert and search test passed")
+
+    def test_search_with_filter(self):
+        """Test search with filters."""
+        assert DB.CockroachDB.value == "CockroachDB"
+
+        dbcls = DB.CockroachDB.init_cls
+
+        config = {
+            "host": "localhost",
+            "port": 26257,
+            "user_name": "root",
+            "password": "",
+            "db_name": "defaultdb",
+            "table_name": "test_cockroachdb_filter",
+        }
+
+        dim = 128
+        count = 1000
+        filter_value = 0.9
+
+        cockroachdb = dbcls(
+            dim=dim,
+            db_config=config,
+            db_case_config=None,
+            collection_name="test_cockroachdb_filter",
+            drop_old=True,
+        )
+
+        embeddings = [[np.random.random() for _ in range(dim)] for _ in range(count)]
+
+        # Insert data
+        with cockroachdb.init():
+            res = cockroachdb.insert_embeddings(embeddings=embeddings, metadata=list(range(count)))
+            assert res[0] == count, "Insert count mismatch"
+
+        # Search with filter
+        with cockroachdb.init():
+            filter_id = int(count * filter_value)
+            test_id = np.random.randint(filter_id, count)
+            q = embeddings[test_id]
+
+            filters = IntFilter(int_value=filter_id, filter_rate=0.9)
+            cockroachdb.prepare_filter(filters)
+
+            res = cockroachdb.search_embedding(query=q, k=10)
+
+            assert len(res) > 0, "Filtered search returned no results"
+            assert res[0] == int(test_id), f"Top result {res[0]} != query id {test_id}"
+
+            # Verify all results clear the filter threshold
+            for result_id in res:
+                assert int(result_id) >= filter_id, f"Result {result_id} < filter threshold {filter_id}"
+
+        log.info("CockroachDB filter test passed")
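These tests hit a live cluster rather than a mock; a minimal programmatic driver (a sketch, assuming a single-node cluster is already listening on localhost:26257 as described in the module docstring):

    import pytest

    # Run only the CockroachDB tests; -x stops at the first failure.
    raise SystemExit(pytest.main(["-x", "tests/test_cockroachdb.py"]))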
"MongoDB" TiDB = "TiDB" + CockroachDB = "CockroachDB" Clickhouse = "Clickhouse" Vespa = "Vespa" LanceDB = "LanceDB" @@ -175,6 +176,11 @@ def init_cls(self) -> type[VectorDB]: # noqa: PLR0911, PLR0912, C901, PLR0915 return TiDB + if self == DB.CockroachDB: + from .cockroachdb.cockroachdb import CockroachDB + + return CockroachDB + if self == DB.Test: from .test.test import Test @@ -326,6 +332,11 @@ def config_cls(self) -> type[DBConfig]: # noqa: PLR0911, PLR0912, C901, PLR0915 return TiDBConfig + if self == DB.CockroachDB: + from .cockroachdb.config import CockroachDBConfig + + return CockroachDBConfig + if self == DB.Test: from .test.config import TestConfig @@ -458,6 +469,11 @@ def case_config_cls( # noqa: C901, PLR0911, PLR0912 return TiDBIndexConfig + if self == DB.CockroachDB: + from .cockroachdb.config import _cockroachdb_case_config + + return _cockroachdb_case_config.get(index_type) + if self == DB.Vespa: from .vespa.config import VespaHNSWConfig diff --git a/vectordb_bench/backend/clients/cockroachdb/cli.py b/vectordb_bench/backend/clients/cockroachdb/cli.py new file mode 100644 index 000000000..1a21d7921 --- /dev/null +++ b/vectordb_bench/backend/clients/cockroachdb/cli.py @@ -0,0 +1,110 @@ +"""CLI parameter definitions for CockroachDB.""" + +from typing import Annotated, Unpack + +import click +from pydantic import SecretStr + +from vectordb_bench.backend.clients import DB + +from ....cli.cli import ( + CommonTypedDict, + cli, + click_parameter_decorators_from_typed_dict, + get_custom_case_config, + run, +) + + +class CockroachDBTypedDict(CommonTypedDict): + """Type definition for CockroachDB CLI parameters.""" + + user_name: Annotated[ + str, + click.option("--user-name", type=str, help="CockroachDB username", default="root", show_default=True), + ] + password: Annotated[ + str, + click.option("--password", type=str, help="CockroachDB password", default="", show_default=False), + ] + host: Annotated[ + str, + click.option("--host", type=str, help="CockroachDB host", required=True), + ] + port: Annotated[ + int, + click.option("--port", type=int, help="CockroachDB port", default=26257, show_default=True), + ] + db_name: Annotated[ + str, + click.option("--db-name", type=str, help="Database name", required=True), + ] + min_partition_size: Annotated[ + int | None, + click.option( + "--min-partition-size", + type=int, + help="Minimum vectors per partition (default: 16, range: 1-1024)", + default=16, + show_default=True, + ), + ] + max_partition_size: Annotated[ + int | None, + click.option( + "--max-partition-size", + type=int, + help="Maximum vectors per partition (default: 128, range: 4x min-4096)", + default=128, + show_default=True, + ), + ] + vector_search_beam_size: Annotated[ + int | None, + click.option( + "--vector-search-beam-size", + type=int, + help="Partitions explored during search (default: 32)", + default=32, + show_default=True, + ), + ] + + +@cli.command() +@click_parameter_decorators_from_typed_dict(CockroachDBTypedDict) +def CockroachDB( + **parameters: Unpack[CockroachDBTypedDict], +): + """Run CockroachDB vector benchmark.""" + from .config import CockroachDBConfig, CockroachDBVectorIndexConfig + + parameters["custom_case"] = get_custom_case_config(parameters) + + from vectordb_bench.backend.clients.api import MetricType + + # Use provided metric_type or default to COSINE + metric_type = parameters.get("metric_type") + if metric_type is None: + metric_type = MetricType.COSINE + elif isinstance(metric_type, str): + metric_type = MetricType(metric_type) + + run( 
diff --git a/vectordb_bench/backend/clients/cockroachdb/cli.py b/vectordb_bench/backend/clients/cockroachdb/cli.py
new file mode 100644
index 000000000..1a21d7921
--- /dev/null
+++ b/vectordb_bench/backend/clients/cockroachdb/cli.py
@@ -0,0 +1,110 @@
+"""CLI parameter definitions for CockroachDB."""
+
+from typing import Annotated, Unpack
+
+import click
+from pydantic import SecretStr
+
+from vectordb_bench.backend.clients import DB
+
+from ....cli.cli import (
+    CommonTypedDict,
+    cli,
+    click_parameter_decorators_from_typed_dict,
+    get_custom_case_config,
+    run,
+)
+
+
+class CockroachDBTypedDict(CommonTypedDict):
+    """Type definition for CockroachDB CLI parameters."""
+
+    user_name: Annotated[
+        str,
+        click.option("--user-name", type=str, help="CockroachDB username", default="root", show_default=True),
+    ]
+    password: Annotated[
+        str,
+        click.option("--password", type=str, help="CockroachDB password", default="", show_default=False),
+    ]
+    host: Annotated[
+        str,
+        click.option("--host", type=str, help="CockroachDB host", required=True),
+    ]
+    port: Annotated[
+        int,
+        click.option("--port", type=int, help="CockroachDB port", default=26257, show_default=True),
+    ]
+    db_name: Annotated[
+        str,
+        click.option("--db-name", type=str, help="Database name", required=True),
+    ]
+    min_partition_size: Annotated[
+        int | None,
+        click.option(
+            "--min-partition-size",
+            type=int,
+            help="Minimum vectors per partition (default: 16, range: 1-1024)",
+            default=16,
+            show_default=True,
+        ),
+    ]
+    max_partition_size: Annotated[
+        int | None,
+        click.option(
+            "--max-partition-size",
+            type=int,
+            help="Maximum vectors per partition (default: 128, range: 4x min-4096)",
+            default=128,
+            show_default=True,
+        ),
+    ]
+    vector_search_beam_size: Annotated[
+        int | None,
+        click.option(
+            "--vector-search-beam-size",
+            type=int,
+            help="Partitions explored during search (default: 32)",
+            default=32,
+            show_default=True,
+        ),
+    ]
+
+
+@cli.command()
+@click_parameter_decorators_from_typed_dict(CockroachDBTypedDict)
+def CockroachDB(
+    **parameters: Unpack[CockroachDBTypedDict],
+):
+    """Run CockroachDB vector benchmark."""
+    from .config import CockroachDBConfig, CockroachDBVectorIndexConfig
+
+    parameters["custom_case"] = get_custom_case_config(parameters)
+
+    from vectordb_bench.backend.clients.api import MetricType
+
+    # Use provided metric_type or default to COSINE
+    metric_type = parameters.get("metric_type")
+    if metric_type is None:
+        metric_type = MetricType.COSINE
+    elif isinstance(metric_type, str):
+        metric_type = MetricType(metric_type)
+
+    run(
+        db=DB.CockroachDB,
+        db_config=CockroachDBConfig(
+            db_label=parameters["db_label"],
+            user_name=SecretStr(parameters["user_name"]),
+            password=SecretStr(parameters["password"]) if parameters["password"] else None,
+            host=parameters["host"],
+            port=parameters["port"],
+            db_name=parameters["db_name"],
+        ),
+        db_case_config=CockroachDBVectorIndexConfig(
+            metric_type=metric_type,
+            min_partition_size=parameters.get("min_partition_size", 16),
+            max_partition_size=parameters.get("max_partition_size", 128),
+            vector_search_beam_size=parameters.get("vector_search_beam_size", 32),
+        ),
+        **parameters,
+    )
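Click derives the subcommand name from the decorated function, so the command should surface as `cockroachdb` (an assumption from click's default naming, not verified here). A smoke test of the registered options:

    from click.testing import CliRunner

    import vectordb_bench.cli.vectordbbench  # noqa: F401  (registers the command on the group)
    from vectordb_bench.cli.cli import cli

    result = CliRunner().invoke(cli, ["cockroachdb", "--help"])
    print(result.output)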
diff --git a/vectordb_bench/backend/clients/cockroachdb/cockroachdb.py b/vectordb_bench/backend/clients/cockroachdb/cockroachdb.py
new file mode 100644
index 000000000..dd4016d7f
--- /dev/null
+++ b/vectordb_bench/backend/clients/cockroachdb/cockroachdb.py
@@ -0,0 +1,439 @@
+"""CockroachDB vector database client with connection pooling and retry logic."""
+
+import logging
+from collections.abc import Generator
+from contextlib import contextmanager
+from typing import Any
+
+import numpy as np
+import psycopg
+from pgvector.psycopg import register_vector
+from psycopg import Connection, Cursor, sql
+from psycopg_pool import ConnectionPool
+
+from vectordb_bench.backend.filter import Filter, FilterOp
+
+from ..api import VectorDB
+from .config import CockroachDBIndexConfig
+from .db_retry import db_retry
+
+log = logging.getLogger(__name__)
+
+
+class CockroachDB(VectorDB):
+    """
+    CockroachDB vector database client:
+    - Connection pooling (100+ connections for high throughput)
+    - Automatic retry for serialization errors (40001, 40003)
+    - Vector index support (C-SPANN algorithm)
+    - Multi-region resilience
+    """
+
+    supported_filter_types: list[FilterOp] = [
+        FilterOp.NonFilter,
+        FilterOp.NumGE,
+        FilterOp.StrEqual,
+    ]
+
+    def __init__(
+        self,
+        dim: int,
+        db_config: dict,
+        db_case_config: CockroachDBIndexConfig | None,
+        collection_name: str = "vdbbench_cockroachdb",
+        drop_old: bool = False,
+        with_scalar_labels: bool = False,
+        **kwargs,
+    ):
+        self.name = "CockroachDB"
+        self.case_config = db_case_config
+        self.table_name = collection_name
+
+        # Handle both dict-style config (from to_dict()) and direct dict
+        if "connect_config" in db_config:
+            self.connect_config = db_config["connect_config"]
+        else:
+            # Direct connection config for tests
+            conn_params = {
+                "host": db_config.get("host", "localhost"),
+                "port": db_config.get("port", 26257),
+                "dbname": db_config.get("db_name", "defaultdb"),
+                "user": db_config.get("user_name", "root"),
+                "password": db_config.get("password", ""),
+            }
+            # Add sslmode if specified, otherwise default to disable for local dev
+            conn_params["sslmode"] = db_config.get("sslmode", "disable")
+            self.connect_config = conn_params
+
+        self.pool_size = db_config.get("pool_size", 100)
+        self.max_overflow = db_config.get("max_overflow", 100)
+        self.pool_recycle = db_config.get("pool_recycle", 3600)
+
+        self.dim = dim
+        self.with_scalar_labels = with_scalar_labels
+
+        self._index_name = f"{self.table_name}_vector_idx"
+        self._primary_field = "id"  # UUID for distribution
+        self._metadata_field = "metadata_id"  # BIGINT for framework compatibility
+        self._vector_field = "embedding"
+        self._scalar_label_field = "label"
+
+        self.pool: ConnectionPool | None = None
+        self.conn: Connection | None = None
+        self.cursor: Cursor | None = None
+
+        # Log the connection config without the password
+        safe_config = {k: v for k, v in self.connect_config.items() if k != "password"}
+        log.info(f"{self.name} config: {safe_config}, pool_size={self.pool_size}")
+
+        # Allow manual index creation (both flags can be False)
+        # This is useful when CREATE INDEX times out in subprocess on multi-node clusters
+        if self.case_config is not None and not any(
+            (self.case_config.create_index_before_load, self.case_config.create_index_after_load)
+        ):
+            log.warning(f"{self.name}: Both create_index flags are False - expecting manually created index")
+
+        # Initialize with temporary connection for setup
+        conn, cursor = self._create_connection(**self.connect_config)
+        try:
+            # Enable pgvector extension (in transaction)
+            cursor.execute("CREATE EXTENSION IF NOT EXISTS vector")
+            conn.commit()
+            cursor.close()
+
+            # Enable vector indexes at cluster level (requires autocommit, not in transaction)
+            conn.autocommit = True
+            cursor = conn.cursor()
+            try:
+                cursor.execute("SET CLUSTER SETTING feature.vector_index.enabled = true")
+            except Exception as e:
+                # May already be enabled or permission issue, log and continue
+                log.warning(f"Could not enable vector indexes: {e}")
+            cursor.close()
+
+            # Reset to transaction mode for remaining operations
+            conn.autocommit = False
+            cursor = conn.cursor()
+
+            if drop_old:
+                if self.case_config is not None:
+                    self._drop_index()  # DDL on a dedicated autocommit connection
+                self._drop_table(cursor, conn)
+            self._create_table(cursor, conn, dim)
+            if self.case_config is not None and self.case_config.create_index_before_load:
+                self._create_index()  # DDL on a dedicated autocommit connection
+        finally:
+            cursor.close()
+            conn.close()
+
+    @staticmethod
+    def _create_connection(**kwargs) -> tuple[Connection, Cursor]:
+        """Create a single connection with pgvector support."""
+        conn = psycopg.connect(**kwargs)
+        register_vector(conn)
+        conn.autocommit = False
+        cursor = conn.cursor()
+        return conn, cursor
+
+    def _create_connection_pool(self) -> ConnectionPool:
+        """Create connection pool with production settings."""
+        # Build a libpq conninfo string; the statement-timeout 'options' key is
+        # embedded here because psycopg_pool does not accept it as a keyword argument
+        conninfo = (
+            f"host={self.connect_config['host']} "
+            f"port={self.connect_config['port']} "
+            f"dbname={self.connect_config['dbname']} "
+            f"user={self.connect_config['user']} "
+            f"password={self.connect_config['password']}"
+        )
+
+        # Add sslmode if present
+        if "sslmode" in self.connect_config:
+            conninfo += f" sslmode={self.connect_config['sslmode']}"
+
+        # Add statement timeout for long-running vector index operations
+        conninfo += " options='-c statement_timeout=600s'"
+
+        return ConnectionPool(
+            conninfo=conninfo,
+            min_size=self.pool_size,
+            max_size=self.pool_size + self.max_overflow,
+            max_lifetime=self.pool_recycle,
+            max_idle=300,
+            reconnect_timeout=10.0,
+            configure=lambda conn: register_vector(conn),
+        )
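+
+    # Sizing note: with the defaults (pool_size=100, max_overflow=100) the pool
+    # keeps 100 connections warm and can grow to 200 under load; max_lifetime
+    # recycles connections after pool_recycle seconds, and max_idle=300 trims
+    # idle connections back toward min_size.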
+
+    @contextmanager
+    def init(self) -> Generator[None, None, None]:
+        """Initialize connection pool for benchmark operations."""
+        self.pool = self._create_connection_pool()
+
+        try:
+            with self.pool.connection() as conn:
+                conn.autocommit = False
+                self.conn = conn
+                self.cursor = conn.cursor()
+
+                # Set session parameters (only if case_config is provided)
+                if self.case_config is not None:
+                    session_options = self.case_config.session_param()["session_options"]
+                    for setting in session_options:
+                        param = setting["parameter"]
+                        command = sql.SQL("SET {setting_name} = {val};").format(
+                            setting_name=sql.Identifier(param["setting_name"]),
+                            val=sql.Literal(int(param["val"])),
+                        )
+                        log.debug(command.as_string(self.cursor))
+                        self.cursor.execute(command)
+                    conn.commit()
+
+                yield
+        finally:
+            if self.cursor:
+                self.cursor.close()
+            if self.conn:
+                self.conn.close()
+            if self.pool:
+                self.pool.close()
+            self.cursor = None
+            self.conn = None
+            self.pool = None
+
+    @db_retry(max_attempts=3, initial_delay=0.5, backoff_factor=2.0)
+    def _drop_table(self, cursor: Cursor, conn: Connection):
+        """Drop table with retry logic."""
+        log.info(f"{self.name} dropping table: {self.table_name}")
+        cursor.execute(
+            sql.SQL("DROP TABLE IF EXISTS {table_name} CASCADE").format(
+                table_name=sql.Identifier(self.table_name),
+            ),
+        )
+        conn.commit()
+
+    def _drop_index(self):
+        """Drop CockroachDB vector index if it exists (DDL with autocommit)."""
+        log.info(f"{self.name} dropping index: {self._index_name}")
+        conn = psycopg.connect(**self.connect_config)
+        conn.autocommit = True
+        try:
+            cursor = conn.cursor()
+            cursor.execute(f"DROP INDEX IF EXISTS {self._index_name}")
+            cursor.close()
+        finally:
+            conn.close()
+
+    @db_retry(max_attempts=3, initial_delay=0.5, backoff_factor=2.0)
+    def _create_table(self, cursor: Cursor, conn: Connection, dim: int):
+        """Create table with VECTOR column."""
+        log.info(f"{self.name} creating table: {self.table_name}")
+
+        # CockroachDB best practice: Use UUID primary key to avoid hotspots in distributed deployments
+        # Keep metadata_id as BIGINT for framework compatibility
+        if self.with_scalar_labels:
+            cursor.execute(
+                sql.SQL(
+                    """
+                    CREATE TABLE IF NOT EXISTS {table_name}
+                    ({primary_field} UUID PRIMARY KEY DEFAULT gen_random_uuid(),
+                    {metadata_field} BIGINT NOT NULL,
+                    {vector_field} VECTOR({dim}),
+                    {label_field} VARCHAR(64));
+                    """,
+                ).format(
+                    table_name=sql.Identifier(self.table_name),
+                    primary_field=sql.Identifier(self._primary_field),
+                    metadata_field=sql.Identifier(self._metadata_field),
+                    vector_field=sql.Identifier(self._vector_field),
+                    label_field=sql.Identifier(self._scalar_label_field),
+                    dim=dim,
+                )
+            )
+        else:
+            cursor.execute(
+                sql.SQL(
+                    """
+                    CREATE TABLE IF NOT EXISTS {table_name}
+                    ({primary_field} UUID PRIMARY KEY DEFAULT gen_random_uuid(),
+                    {metadata_field} BIGINT NOT NULL,
+                    {vector_field} VECTOR({dim}));
+                    """
+                ).format(
+                    table_name=sql.Identifier(self.table_name),
+                    primary_field=sql.Identifier(self._primary_field),
+                    metadata_field=sql.Identifier(self._metadata_field),
+                    vector_field=sql.Identifier(self._vector_field),
+                    dim=dim,
+                )
+            )
+
+        # Note: CockroachDB doesn't support SET STORAGE PLAIN (PostgreSQL-specific)
+        # Vector columns are handled automatically
+        conn.commit()
+
+    def _create_index(self):
+        """Create CockroachDB vector index (DDL with autocommit)."""
+        log.info(f"{self.name} creating vector index: {self._index_name}")
+
+        index_param = self.case_config.index_param()
+
+        # Build WITH clause for index parameters
+        options_list = []
+        for option in index_param["index_creation_with_options"]:
+            if option["val"] is not None:
+                options_list.append(f"{option['option_name']} = {option['val']}")
+
+        with_clause = f" WITH ({', '.join(options_list)})" if options_list else ""
+
+        # Build SQL string (DDL - no need for parameterization)
+        sql_str = (
+            f"CREATE VECTOR INDEX IF NOT EXISTS {self._index_name} "
+            f"ON {self.table_name} ({self._vector_field} {index_param['metric']})"
+            f"{with_clause}"
+        )
+
+        log.info(f"Creating index with SQL: {sql_str}")
+
+        # Use autocommit for DDL
+        conn = psycopg.connect(**self.connect_config)
+        conn.autocommit = True
+        try:
+            cursor = conn.cursor()
+            cursor.execute(sql_str)
+            cursor.close()
+        finally:
+            conn.close()
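+
+    # With a cosine case config and the defaults above, _create_index() issues
+    # roughly:
+    #   CREATE VECTOR INDEX IF NOT EXISTS vdbbench_cockroachdb_vector_idx
+    #   ON vdbbench_cockroachdb (embedding vector_cosine_ops)
+    #   WITH (min_partition_size = 16, max_partition_size = 128, build_beam_size = 8)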
+
+    def optimize(self, data_size: int | None = None):
+        """Post-insert optimization: create index if needed.
+
+        Note: Uses connection pool instead of creating new connection to avoid
+        subprocess timeout issues in CockroachDB.
+        """
+        log.info(f"{self.name} post-insert optimization")
+        if self.case_config is not None and self.case_config.create_index_after_load:
+            # Use existing pool connection instead of creating new one
+            with self.pool.connection() as conn:
+                register_vector(conn)
+                conn.autocommit = True
+                cursor = conn.cursor()
+
+                try:
+                    # Build CREATE INDEX SQL (skip the DROP to avoid timeouts)
+                    index_param = self.case_config.index_param()
+                    options_list = []
+                    for option in index_param["index_creation_with_options"]:
+                        if option["val"] is not None:
+                            options_list.append(f"{option['option_name']} = {option['val']}")
+
+                    with_clause = f" WITH ({', '.join(options_list)})" if options_list else ""
+                    sql_str = (
+                        f"CREATE VECTOR INDEX IF NOT EXISTS {self._index_name} "
+                        f"ON {self.table_name} ({self._vector_field} {index_param['metric']})"
+                        f"{with_clause}"
+                    )
+
+                    log.info(f"{self.name} creating vector index: {self._index_name}")
+                    log.info(f"Index SQL: {sql_str}")
+                    cursor.execute(sql_str)
+                finally:
+                    cursor.close()
+
+    @db_retry(max_attempts=3, initial_delay=0.5, backoff_factor=2.0)
+    def insert_embeddings(
+        self,
+        embeddings: list[list[float]],
+        metadata: list[int],
+        labels_data: list[str] | None = None,
+        **kwargs: Any,
+    ) -> tuple[int, Exception | None]:
+        """Insert embeddings with COPY for performance."""
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+
+        if self.with_scalar_labels:
+            assert labels_data is not None, "labels_data required when with_scalar_labels=True"
+
+        try:
+            metadata_arr = np.array(metadata)
+            embeddings_arr = np.array(embeddings)
+
+            # UUID primary key is auto-generated, we only insert metadata_id and embedding
+            with self.cursor.copy(
+                sql.SQL(
+                    "COPY {table_name} ({metadata_field}, {vector_field}{label_field}) FROM STDIN (FORMAT BINARY)"
+                ).format(
+                    table_name=sql.Identifier(self.table_name),
+                    metadata_field=sql.Identifier(self._metadata_field),
+                    vector_field=sql.Identifier(self._vector_field),
+                    label_field=sql.SQL(f", {self._scalar_label_field}") if self.with_scalar_labels else sql.SQL(""),
+                )
+            ) as copy:
+                # Declare the row layout once, before streaming rows
+                if self.with_scalar_labels:
+                    copy.set_types(["bigint", "vector", "varchar"])
+                    for i, row in enumerate(metadata_arr):
+                        copy.write_row((row, embeddings_arr[i], labels_data[i]))
+                else:
+                    copy.set_types(["bigint", "vector"])
+                    for i, row in enumerate(metadata_arr):
+                        copy.write_row((row, embeddings_arr[i]))
+
+            self.conn.commit()
+            return len(metadata), None
+
+        except Exception as e:
+            log.warning(f"Failed to insert data into {self.table_name}: {e}")
+            return 0, e
+
+    def prepare_filter(self, filters: Filter):
+        """Prepare WHERE clause for filtered queries."""
+        if filters.type == FilterOp.NonFilter:
+            self.where_clause = ""
+        elif filters.type == FilterOp.NumGE:
+            # Filter on metadata_id, not UUID primary key
+            self.where_clause = f"WHERE {self._metadata_field} >= {filters.int_value}"
+        elif filters.type == FilterOp.StrEqual:
+            self.where_clause = f"WHERE {self._scalar_label_field} = '{filters.label_value}'"
+        else:
+            msg = f"Unsupported filter for CockroachDB: {filters}"
+            raise ValueError(msg)
+
+    def ready_to_load(self):
+        """No-op: CockroachDB needs no preparation before loading."""
+
+    @db_retry(max_attempts=3, initial_delay=0.5, backoff_factor=2.0)
+    def search_embedding(
+        self,
+        query: list[float],
+        k: int = 100,
+        timeout: int | None = None,
+        **kwargs: Any,
+    ) -> list[int]:
+        """Search for k nearest neighbors using vector index."""
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+
+        # Use default L2 distance if no case_config provided
+        if self.case_config is not None:
+            search_param = self.case_config.search_param()
+            metric_op = search_param["metric_fun_op"]
+        else:
+            metric_op = "<->"  # Default to L2 distance
+
+        q = np.asarray(query)
+
+        # Build search query - return metadata_id for framework compatibility
+        search_sql = sql.SQL(
+            "SELECT {metadata_field} FROM {table_name} {where_clause} ORDER BY {vector_field}"
+        ).format(
+            metadata_field=sql.Identifier(self._metadata_field),
+            table_name=sql.Identifier(self.table_name),
+            where_clause=sql.SQL(getattr(self, "where_clause", "")),
+            vector_field=sql.Identifier(self._vector_field),
+        )
+
+        # Add distance operator and limit
+        full_sql = search_sql + sql.SQL(" {metric_op} %s LIMIT %s").format(
+            metric_op=sql.SQL(metric_op),
+        )
+
+        result = self.cursor.execute(full_sql, (q, k), prepare=True, binary=True)
+        return [int(i[0]) for i in result.fetchall()]
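For reference, the statement `search_embedding` composes for a cosine run with a NumGE filter reduces to the following shape (identifiers are the defaults above; the query vector and k are bound as parameters, only the operator and identifiers are inlined):

    query_sql = (
        'SELECT "metadata_id" FROM "vdbbench_cockroachdb" '
        "WHERE metadata_id >= 900 "
        'ORDER BY "embedding" <=> %s LIMIT %s'
    )
    # cursor.execute(query_sql, (np.asarray(query), k), prepare=True, binary=True)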
diff --git a/vectordb_bench/backend/clients/cockroachdb/config.py b/vectordb_bench/backend/clients/cockroachdb/config.py
new file mode 100644
index 000000000..79e2ff912
--- /dev/null
+++ b/vectordb_bench/backend/clients/cockroachdb/config.py
@@ -0,0 +1,202 @@
+"""Configuration classes for CockroachDB vector database integration."""
+
+from abc import abstractmethod
+from collections.abc import Mapping, Sequence
+from typing import Any, LiteralString, TypedDict
+
+from pydantic import BaseModel, SecretStr
+
+from ..api import DBCaseConfig, DBConfig, IndexType, MetricType
+
+
+class CockroachDBConfigDict(TypedDict):
+    """Shape of the "connect_config" mapping passed to psycopg."""
+
+    user: str
+    password: str
+    host: str
+    port: int
+    dbname: str
+    sslmode: str
+
+
+class CockroachDBConfig(DBConfig):
+    """Main configuration for CockroachDB connection."""
+
+    user_name: SecretStr = "root"
+    password: SecretStr | None = None
+    host: str = "localhost"
+    port: int = 26257
+    db_name: str = "defaultdb"
+    table_name: str = "vdbbench_cockroachdb"
+    isolation_level: str = "serializable"
+    pool_size: int = 100
+    max_overflow: int = 100
+    pool_recycle: int = 3600
+    connect_timeout: int = 10
+
+    def to_dict(self) -> dict:
+        user_str = self.user_name.get_secret_value() if isinstance(self.user_name, SecretStr) else self.user_name
+        pwd_str = self.password.get_secret_value() if self.password else ""
+
+        return {
+            "connect_config": {
+                "host": self.host,
+                "port": self.port,
+                "dbname": self.db_name,
+                "user": user_str,
+                "password": pwd_str,
+                "sslmode": "disable",  # Default for local dev; production should override
+            },
+            "table_name": self.table_name,
+            "pool_size": self.pool_size,
+            "max_overflow": self.max_overflow,
+            "pool_recycle": self.pool_recycle,
+            "connect_timeout": self.connect_timeout,
+        }
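+
+
+# For the defaults, to_dict() produces a nested mapping along these lines:
+#   {"connect_config": {"host": "localhost", "port": 26257, "dbname": "defaultdb",
+#                       "user": "root", "password": "", "sslmode": "disable"},
+#    "table_name": "vdbbench_cockroachdb", "pool_size": 100, "max_overflow": 100,
+#    "pool_recycle": 3600, "connect_timeout": 10}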
+
+
+class CockroachDBIndexParam(TypedDict):
+    """Index parameters for CockroachDB vector indexes."""
+
+    metric: str
+    index_creation_with_options: Sequence[dict[str, Any]]
+    min_partition_size: int | None
+    max_partition_size: int | None
+    build_beam_size: int | None
+
+
+class CockroachDBSearchParam(TypedDict):
+    """Search parameters for CockroachDB vector queries."""
+
+    metric_fun_op: LiteralString
+    vector_search_beam_size: int | None
+
+
+class CockroachDBSessionCommands(TypedDict):
+    """Session-level commands for CockroachDB."""
+
+    session_options: Sequence[dict[str, Any]]
+
+
+class CockroachDBIndexConfig(BaseModel, DBCaseConfig):
+    """Base configuration for CockroachDB vector indexes."""
+
+    metric_type: MetricType | None = None
+    create_index_before_load: bool = False
+    create_index_after_load: bool = True
+    min_partition_size: int | None = 16
+    max_partition_size: int | None = 128
+    build_beam_size: int | None = 8
+    vector_search_beam_size: int | None = 32
+
+    def parse_metric(self) -> str:
+        """Parse metric type to CockroachDB opclass."""
+        metric_map = {
+            MetricType.L2: "vector_l2_ops",
+            MetricType.IP: "vector_ip_ops",
+            MetricType.COSINE: "vector_cosine_ops",
+        }
+        metric = metric_map.get(self.metric_type)
+        if metric is None:
+            msg = f"Unsupported metric type: {self.metric_type}"
+            raise ValueError(msg)
+        return metric
+
+    def parse_metric_fun_op(self) -> LiteralString:
+        """Parse metric type to distance operator."""
+        if self.metric_type == MetricType.L2:
+            return "<->"
+        if self.metric_type == MetricType.IP:
+            return "<#>"
+        return "<=>"
+
+    @abstractmethod
+    def index_param(self) -> CockroachDBIndexParam: ...
+
+    @abstractmethod
+    def search_param(self) -> CockroachDBSearchParam: ...
+
+    @abstractmethod
+    def session_param(self) -> CockroachDBSessionCommands: ...
+
+    @staticmethod
+    def _optionally_build_with_options(with_options: Mapping[str, Any]) -> Sequence[dict[str, Any]]:
+        """Build WITH options for index creation."""
+        options = []
+        for option_name, value in with_options.items():
+            if value is not None:
+                options.append(
+                    {
+                        "option_name": option_name,
+                        "val": str(value),
+                    },
+                )
+        return options
+
+    @staticmethod
+    def _optionally_build_set_options(set_mapping: Mapping[str, Any]) -> Sequence[dict[str, Any]]:
+        """Build SET options for session configuration."""
+        session_options = []
+        for setting_name, value in set_mapping.items():
+            if value is not None:
+                session_options.append(
+                    {
+                        "parameter": {
+                            "setting_name": setting_name,
+                            "val": str(value),
+                        },
+                    },
+                )
+        return session_options
+
+
+class CockroachDBVectorIndexConfig(CockroachDBIndexConfig):
+    """
+    CockroachDB Vector Index Configuration using C-SPANN algorithm.
+
+    Available since CockroachDB v25.2. Uses hierarchical k-means clustering
+    for efficient approximate nearest neighbor (ANN) search.
+
+    Tunable parameters:
+    - min_partition_size: Minimum vectors per partition (default: 16, range: 1-1024)
+    - max_partition_size: Maximum vectors per partition (default: 128, range: 4x min-4096)
+    - vector_search_beam_size: Partitions explored during search (default: 32)
+    """
+
+    index: IndexType = IndexType.Flat
+
+    def index_param(self) -> CockroachDBIndexParam:
+        """Get index creation parameters."""
+        index_parameters = {
+            "min_partition_size": self.min_partition_size,
+            "max_partition_size": self.max_partition_size,
+            "build_beam_size": self.build_beam_size,
+        }
+
+        return {
+            "metric": self.parse_metric(),
+            "index_creation_with_options": self._optionally_build_with_options(index_parameters),
+            "min_partition_size": self.min_partition_size,
+            "max_partition_size": self.max_partition_size,
+            "build_beam_size": self.build_beam_size,
+        }
+
+    def search_param(self) -> CockroachDBSearchParam:
+        """Get search parameters."""
+        return {
+            "metric_fun_op": self.parse_metric_fun_op(),
+            "vector_search_beam_size": self.vector_search_beam_size,
+        }
+
+    def session_param(self) -> CockroachDBSessionCommands:
+        """Get session parameters."""
+        session_parameters = {"vector_search_beam_size": self.vector_search_beam_size}
+        return {"session_options": self._optionally_build_set_options(session_parameters)}
+
+
+_cockroachdb_case_config = {
+    IndexType.Flat: CockroachDBVectorIndexConfig,
+    IndexType.HNSW: CockroachDBVectorIndexConfig,
+    IndexType.IVFFlat: CockroachDBVectorIndexConfig,
+}
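Putting the case config together: a short sketch (all names come from this patch) of the parameters the client receives for a cosine run, with a larger beam for higher recall:

    from vectordb_bench.backend.clients.api import MetricType
    from vectordb_bench.backend.clients.cockroachdb.config import CockroachDBVectorIndexConfig

    cfg = CockroachDBVectorIndexConfig(
        metric_type=MetricType.COSINE,
        vector_search_beam_size=64,  # explore more C-SPANN partitions per query
    )
    print(cfg.index_param())   # opclass + WITH options for CREATE VECTOR INDEX
    print(cfg.search_param())  # {'metric_fun_op': '<=>', 'vector_search_beam_size': 64}
    print(cfg.session_param()) # becomes SET "vector_search_beam_size" = 64 in init()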
diff --git a/vectordb_bench/backend/clients/cockroachdb/db_retry.py b/vectordb_bench/backend/clients/cockroachdb/db_retry.py
new file mode 100644
index 000000000..3481c321b
--- /dev/null
+++ b/vectordb_bench/backend/clients/cockroachdb/db_retry.py
@@ -0,0 +1,111 @@
+"""
+Database retry utilities for handling CockroachDB transient connection failures.
+
+Implements retry logic for CockroachDB-specific error codes:
+- 40001: Serialization failure
+- 40003: Statement completion unknown (multi-region ambiguous results)
+"""
+
+import functools
+import logging
+import time
+from collections.abc import Callable
+
+from psycopg import InterfaceError, OperationalError
+
+log = logging.getLogger(__name__)
+
+TRANSIENT_ERRORS = (
+    OperationalError,
+    InterfaceError,
+    Exception,  # Broad catch-all; is_transient_error() narrows by message below
+)
+
+TRANSIENT_ERROR_MESSAGES = (
+    "server closed the connection unexpectedly",
+    "connection already closed",
+    "SSL connection has been closed unexpectedly",
+    "could not receive data from server",
+    "connection timed out",
+    "Connection refused",
+    "connection reset by peer",
+    "broken pipe",
+    "restart transaction",
+    "TransactionRetryError",
+    "SerializationFailure",
+    "StatementCompletionUnknown",
+    "result is ambiguous",
+    "failed to connect",
+    "no such host",
+    "initial connection heartbeat failed",
+    "sending to all replicas failed",
+    "40001",
+    "40003",
+)
+
+
+def is_transient_error(error: Exception) -> bool:
+    """Check if an error is transient and should be retried."""
+    if not isinstance(error, TRANSIENT_ERRORS):
+        return False
+
+    error_msg = str(error).lower()
+    return any(msg.lower() in error_msg for msg in TRANSIENT_ERROR_MESSAGES)
+
+
+def db_retry(
+    max_attempts: int = 3,
+    initial_delay: float = 0.5,
+    backoff_factor: float = 2.0,
+    max_delay: float = 10.0,
+    exceptions: tuple[type[Exception], ...] = TRANSIENT_ERRORS,
+):
+    """
+    Decorator to retry database operations on transient failures.
+
+    Uses exponential backoff with configurable parameters per CockroachDB best practices.
+
+    Args:
+        max_attempts: Maximum number of retry attempts (default: 3)
+        initial_delay: Initial retry delay in seconds (default: 0.5)
+        backoff_factor: Multiplier for exponential backoff (default: 2.0)
+        max_delay: Maximum delay between retries (default: 10.0)
+        exceptions: Tuple of exception types to catch and retry
+    """
+
+    def decorator(func: Callable) -> Callable:
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            attempt = 0
+
+            while attempt < max_attempts:
+                try:
+                    return func(*args, **kwargs)
+
+                except exceptions as e:
+                    attempt += 1
+
+                    if not is_transient_error(e):
+                        log.warning(f"Non-transient error in {func.__name__}: {e}")
+                        raise
+
+                    if attempt >= max_attempts:
+                        log.warning(f"Max retry attempts ({max_attempts}) reached for {func.__name__}")
+                        raise
+
+                    # Exponential backoff: initial_delay, doubled per attempt, capped at max_delay
+                    current_delay = min(initial_delay * (backoff_factor ** (attempt - 1)), max_delay)
+
+                    log.info(
+                        f"Transient DB error in {func.__name__} (attempt {attempt}/{max_attempts}): {e}. "
+                        f"Retrying in {current_delay:.2f}s..."
+                    )
+
+                    time.sleep(current_delay)
+
+            msg = f"Unexpected state in retry logic for {func.__name__}"
+            raise RuntimeError(msg)
+
+        return wrapper
+
+    return decorator
diff --git a/vectordb_bench/cli/vectordbbench.py b/vectordb_bench/cli/vectordbbench.py
index 83dab74f6..014cfbc48 100644
--- a/vectordb_bench/cli/vectordbbench.py
+++ b/vectordb_bench/cli/vectordbbench.py
@@ -1,6 +1,7 @@
 from ..backend.clients.alloydb.cli import AlloyDBScaNN
 from ..backend.clients.aws_opensearch.cli import AWSOpenSearch
 from ..backend.clients.clickhouse.cli import Clickhouse
+from ..backend.clients.cockroachdb.cli import CockroachDB as CockroachDBCli
 from ..backend.clients.hologres.cli import HologresHGraph
 from ..backend.clients.lancedb.cli import LanceDB
 from ..backend.clients.mariadb.cli import MariaDBHNSW
@@ -42,6 +43,7 @@
 cli.add_command(OceanBaseIVF)
 cli.add_command(MariaDBHNSW)
 cli.add_command(TiDB)
+cli.add_command(CockroachDBCli)
 cli.add_command(Clickhouse)
 cli.add_command(Vespa)
 cli.add_command(LanceDB)
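Beyond the client methods it already wraps, the decorator from db_retry.py can guard any ad-hoc statement that may hit 40001/40003; a usage sketch (the function and table name here are hypothetical):

    from vectordb_bench.backend.clients.cockroachdb.db_retry import db_retry

    @db_retry(max_attempts=5, initial_delay=0.2, backoff_factor=2.0)
    def count_rows(cursor, table_name: str) -> int:
        # Retried automatically on serialization failures and dropped connections.
        cursor.execute(f"SELECT count(*) FROM {table_name}")
        return cursor.fetchone()[0]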