Skip to content

Commit 7640df0

Browse files
authored
Add SVS support (#23)
1 parent 25ca5ac commit 7640df0

File tree

3 files changed

+64
-6
lines changed

3 files changed

+64
-6
lines changed

engine/clients/redis/search.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,10 @@ def init_client(cls, host, distance, connection_params: dict, search_params: dic
3737
# 'EF_RUNTIME' is irrelevant for 'ADHOC_BF' policy
3838
if cls.hybrid_policy != "ADHOC_BF":
3939
cls.knn_conditions = "EF_RUNTIME $EF"
40-
40+
elif cls.algorithm == "SVS":
41+
cls.knn_conditions = "WS_SEARCH $WS_SEARCH"
42+
elif cls.algorithm == "SVS_TIERED":
43+
cls.knn_conditions = "WS_SEARCH $WS_SEARCH"
4144
cls.data_type = "FLOAT32"
4245
if "search_params" in cls.search_params:
4346
cls.data_type = (
@@ -95,6 +98,10 @@ def search_one(cls, vector, meta_conditions, top) -> List[Tuple[int, float]]:
9598
# 'EF_RUNTIME' is irrelevant for 'ADHOC_BF' policy
9699
if cls.hybrid_policy != "ADHOC_BF":
97100
params_dict["EF"] = cls.search_params["search_params"]["ef"]
101+
if cls.algorithm == "SVS":
102+
params_dict["WS_SEARCH"] = cls.search_params["search_params"]["WS_SEARCH"]
103+
if cls.algorithm == "SVS_TIERED":
104+
params_dict["WS_SEARCH"] = cls.search_params["search_params"]["WS_SEARCH"]
98105
results = cls._ft.search(q, query_params=params_dict)
99106

100107
return [(int(result.id), float(result.vector_score)) for result in results.docs]

engine/clients/redis/upload.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -95,21 +95,21 @@ def upload_batch(
9595

9696
@classmethod
9797
def post_upload(cls, _distance):
98-
if cls.algorithm != "HNSW" and cls.algorithm != "FLAT":
98+
if cls.algorithm != "HNSW" and cls.algorithm != "FLAT" and cls.algorithm != "SVS" and cls.algorithm != "SVS_TIERED":
9999
print(f"TODO: FIXME!! Avoiding calling ft.info for {cls.algorithm}...")
100100
return {}
101101
index_info = cls.client.ft().info()
102102
# redisearch / memorystore for redis
103-
if "percent_index" in index_info:
104-
percent_index = float(index_info["percent_index"])
103+
if "percent_indexed" in index_info:
104+
percent_index = float(index_info["percent_indexed"])
105105
while percent_index < 1.0:
106106
print(
107107
"waiting for index to be fully processed. current percent index: {}".format(
108108
percent_index * 100.0
109109
)
110110
)
111111
time.sleep(1)
112-
percent_index = float(cls.client.ft().info()["percent_index"])
112+
percent_index = float(cls.client.ft().info()["percent_indexed"])
113113
# memorydb
114114
if "current_lag" in index_info:
115115
current_lag = float(index_info["current_lag"])
@@ -136,7 +136,7 @@ def get_memory_usage(cls):
136136
used_memory.append(used_memory_shard)
137137
index_info = {}
138138
device_info = {}
139-
if cls.algorithm != "HNSW" and cls.algorithm != "FLAT":
139+
if cls.algorithm != "HNSW" and cls.algorithm != "FLAT" and cls.algorithm != "SVS" and cls.algorithm != "SVS_TIERED":
140140
print(f"TODO: FIXME!! Avoiding calling ft.info for {cls.algorithm}...")
141141
else:
142142
index_info = cls.client_decode.ft().info()
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import json
2+
3+
threads = [16]
4+
ws_constructs = [100]
5+
ws_search = [32, 40, 48, 64]
6+
#ws_search = [48]
7+
graph_degree = [32]
8+
#quantization = ["0", "4x4", "4x8", "8", "4"]
9+
quantization = ["8"]
10+
topKs = [10]
11+
data_types = ["FLOAT32"]
12+
13+
for algo in ["svs_tiered"]:
14+
for data_type in data_types:
15+
for ws_construct in ws_constructs:
16+
for graph_d in graph_degree:
17+
for quant in quantization:
18+
configs = []
19+
for thread in threads:
20+
config = {
21+
"name": f"svs-test-algo-{algo}-graph-{graph_d}-ws-con-{ws_construct}-quant-{quant}-threads-{thread}-dt-{data_type}",
22+
"engine": "redis",
23+
"connection_params": {},
24+
"collection_params": {
25+
"algorithm": algo,
26+
"data_type": data_type,
27+
f"{algo}_config": {"NUM_THREADS": thread, "GRAPH_DEGREE": graph_d, "WS_CONSTRUCTION": ws_construct, "QUANTIZATION": quant},
28+
},
29+
"search_params": [],
30+
"upload_params": {
31+
"parallel": 128,
32+
"data_type": data_type,
33+
"algorithm": algo,
34+
},
35+
}
36+
for client in [1, 8, 16, 32, 64, 128]:
37+
for ws_s in ws_search:
38+
for top in topKs:
39+
test_config = {
40+
"algorithm": algo,
41+
"parallel": client,
42+
"top": top,
43+
"search_params": {"WS_SEARCH": ws_s, "data_type": data_type},
44+
}
45+
config["search_params"].append(test_config)
46+
configs.append(config)
47+
48+
fname = f"svs-test-algo-{algo}-graph-{graph_d}-ws-con-{ws_construct}-quant-{quant}-threads-{thread}-dt-{data_type}.json"
49+
with open(fname, "w") as json_fd:
50+
json.dump(configs, json_fd, indent=2)
51+
print(f"Created {len(configs)} configs for {fname}.")

0 commit comments

Comments
 (0)