diff --git a/vectordb_bench/backend/clients/milvus/cli.py b/vectordb_bench/backend/clients/milvus/cli.py index 303eec5f9..1bec8ebe5 100644 --- a/vectordb_bench/backend/clients/milvus/cli.py +++ b/vectordb_bench/backend/clients/milvus/cli.py @@ -214,6 +214,36 @@ def MilvusGPUBruteForce(**parameters: Unpack[MilvusGPUBruteForceTypedDict]): **parameters, ) +class MilvusGPUBruteForceTypedDict(CommonTypedDict, MilvusTypedDict): + metric_type: Annotated[ + str, + click.option("--metric-type", type=str, required=True, help="Metric type for brute force search"), + ] + limit: Annotated[ + int, + click.option("--limit", type=int, required=True, help="Top-k limit for search"), + ] + +@cli.command() +@click_parameter_decorators_from_typed_dict(MilvusGPUBruteForceTypedDict) +def MilvusGPUBruteForce(**parameters: Unpack[MilvusGPUBruteForceTypedDict]): + from .config import GPUBruteForceConfig, MilvusConfig + + run( + db=DBTYPE, + db_config=MilvusConfig( + db_label=parameters["db_label"], + uri=SecretStr(parameters["uri"]), + user=parameters["user_name"], + password=SecretStr(parameters["password"]), + ), + db_case_config=GPUBruteForceConfig( + metric_type=parameters["metric_type"], + limit=parameters["limit"], # top-k for search + ), + **parameters, + ) + class MilvusGPUIVFPQTypedDict( CommonTypedDict,