forked from qdrant/vector-db-benchmark
-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathconfigure.py
108 lines (98 loc) · 3.2 KB
/
configure.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import redis
from redis import Redis, RedisCluster
from redis.commands.search.field import (
GeoField,
NumericField,
TagField,
TextField,
VectorField,
)
from benchmark.dataset import Dataset
from engine.base_client.configure import BaseConfigurator
from engine.base_client.distances import Distance
from engine.clients.redis.config import (
REDIS_AUTH,
REDIS_CLUSTER,
REDIS_PORT,
REDIS_USER,
)
class RedisConfigurator(BaseConfigurator):
    """Drop and (re)create the RediSearch index used by the Redis benchmark.

    Talks to a single Redis server, or, when ``REDIS_CLUSTER`` is truthy, to
    every primary node of a Redis Cluster: index operations are issued on
    each primary's connection separately, using the index returned by the
    connection's default ``ft()`` search namespace.
    """

    # Benchmark distance -> RediSearch vector ``DISTANCE_METRIC`` value.
    DISTANCE_MAPPING = {
        Distance.L2: "L2",
        Distance.COSINE: "COSINE",
        Distance.DOT: "IP",
    }

    # Dataset schema type -> RediSearch field class for payload columns.
    # "keyword" columns are built separately in ``_payload_fields`` so they
    # can use a ";" tag separator.
    FIELD_MAPPING = {
        "int": NumericField,
        "keyword": TagField,
        "text": TextField,
        "float": NumericField,
        "geo": GeoField,
    }

    def __init__(self, host, collection_params: dict, connection_params: dict):
        """Connect to ``host`` using the module-level Redis settings.

        Port, password, username and cluster mode come from the
        ``engine.clients.redis.config`` constants. ``connection_params`` is
        only forwarded to the base class, not to the Redis client.
        """
        super().__init__(host, collection_params, connection_params)
        client_cls = RedisCluster if REDIS_CLUSTER else Redis
        self.is_cluster = REDIS_CLUSTER
        self.client = client_cls(
            host=host, port=REDIS_PORT, password=REDIS_AUTH, username=REDIS_USER
        )

    def _connections(self):
        """Return the connections that an index operation must be applied to.

        In standalone mode this is just the client. In cluster mode it is one
        connection per primary node, because the index is managed on each
        primary individually.
        """
        if not self.is_cluster:
            return [self.client]
        return [
            self.client.get_redis_connection(node)
            for node in self.client.get_primaries()
        ]

    @classmethod
    def _payload_fields(cls, schema):
        """Build sortable RediSearch fields for a ``{name: type}`` payload schema.

        Non-keyword columns come first, in schema order, mapped through
        ``FIELD_MAPPING``. Keyword columns follow, in schema order, as
        ``TagField``s split on ";" (presumably how multi-valued keywords are
        joined on upload -- confirm against the uploader).

        Raises:
            KeyError: if a non-keyword type is missing from ``FIELD_MAPPING``.
        """
        regular_fields = [
            cls.FIELD_MAPPING[field_type](name=field_name, sortable=True)
            for field_name, field_type in schema.items()
            if field_type != "keyword"
        ]
        keyword_fields = [
            TagField(name=field_name, separator=";", sortable=True)
            for field_name, field_type in schema.items()
            if field_type == "keyword"
        ]
        return regular_fields + keyword_fields

    def clean(self):
        """Drop the search index, and its documents, on every target connection.

        A missing index is expected (e.g. on a fresh server) and ignored. Any
        other ``ResponseError`` is printed, and cleanup continues with the
        remaining connections (best effort, never raises ``ResponseError``).
        """
        for conn in self._connections():
            search_namespace = conn.ft()
            try:
                search_namespace.dropindex(delete_documents=True)
            except redis.ResponseError as e:
                # Case-insensitive so the check does not hinge on the exact
                # capitalisation of the server's error text.
                if "unknown index name" not in str(e).lower():
                    print(e)

    def recreate(self, dataset: Dataset, collection_params):
        """Replace any existing index with a fresh one for ``dataset``.

        The index consists of an HNSW ``vector`` field followed by the payload
        fields from ``dataset.config.schema``. The vector field is FLOAT32,
        takes its dimension and metric from the dataset config, and takes any
        extra attributes from the ``hnsw_config`` collection parameter.

        Args:
            dataset: Supplies ``config.vector_size``, ``config.distance`` and
                ``config.schema``.
            collection_params: Unused here; the HNSW settings are read from
                ``self.collection_params``. NOTE(review): confirm callers
                always pass that same dict.

        Raises:
            KeyError: For a distance or schema type absent from the mappings.
                This is raised before the existing index is touched.
            redis.ResponseError: If index creation fails for any reason other
                than the index already existing.
        """
        # A null ``hnsw_config`` in the config file means "no extra attributes".
        hnsw_config = self.collection_params.get("hnsw_config") or {}
        # Build (and thereby validate) the schema before dropping anything, so
        # a bad config cannot wipe the existing index and then fail.
        payload_fields = self._payload_fields(dataset.config.schema)
        vector_field = VectorField(
            name="vector",
            algorithm="HNSW",
            attributes={
                "TYPE": "FLOAT32",
                "DIM": dataset.config.vector_size,
                "DISTANCE_METRIC": self.DISTANCE_MAPPING[dataset.config.distance],
                **hnsw_config,
            },
        )
        index_fields = [vector_field, *payload_fields]

        self.clean()
        for conn in self._connections():
            search_namespace = conn.ft()
            try:
                search_namespace.create_index(fields=index_fields)
            except redis.ResponseError as e:
                if "index already exists" not in str(e).lower():
                    raise
if __name__ == "__main__":
    # Nothing to run: this module only provides RedisConfigurator for import.
    ...