upload.py
import multiprocessing as mp
import uuid
from typing import List

from opensearchpy import OpenSearch

from dataset_reader.base_reader import Record
from engine.base_client.upload import BaseUploader
from engine.clients.opensearch.config import (
    OPENSEARCH_INDEX,
    OPENSEARCH_PASSWORD,
    OPENSEARCH_PORT,
    OPENSEARCH_USER,
)


class ClosableOpenSearch(OpenSearch):
    def __del__(self):
        # Close the underlying transport when the client is garbage-collected.
        self.close()


class OpenSearchUploader(BaseUploader):
    client: OpenSearch = None
    upload_params = {}

    @classmethod
    def get_mp_start_method(cls):
        # Use "forkserver" when the platform supports it, otherwise fall back to "spawn".
        return "forkserver" if "forkserver" in mp.get_all_start_methods() else "spawn"

    @classmethod
    def init_client(cls, host, distance, connection_params, upload_params):
        # Defaults can be overridden by whatever comes in via connection_params.
        init_params = {
            **{
                "verify_certs": False,
                "request_timeout": 90,
                "retry_on_timeout": True,
            },
            **connection_params,
        }
        cls.client = OpenSearch(
            f"http://{host}:{OPENSEARCH_PORT}",
            # opensearch-py expects basic credentials via http_auth.
            http_auth=(OPENSEARCH_USER, OPENSEARCH_PASSWORD),
            **init_params,
        )
        cls.upload_params = upload_params

    @classmethod
    def upload_batch(cls, batch: List[Record]):
        # Build a bulk request body: each record contributes an action line
        # ({"index": {"_id": ...}}) followed by a source line with the vector
        # and any metadata.
        operations = []
        for record in batch:
            vector_id = uuid.UUID(int=record.id).hex
            operations.append({"index": {"_id": vector_id}})
            operations.append({"vector": record.vector, **(record.metadata or {})})

        cls.client.bulk(
            index=OPENSEARCH_INDEX,
            body=operations,
            params={
                "timeout": 300,
            },
        )

    @classmethod
    def post_upload(cls, _distance):
        # Force-merge the index after upload so subsequent searches run against
        # a compacted index.
        cls.client.indices.forcemerge(
            index=OPENSEARCH_INDEX,
            params={
                "timeout": 300,
            },
        )
        return {}
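

# A minimal, hypothetical usage sketch (not part of the original file). It assumes
# an OpenSearch node reachable on localhost with the configured credentials, that
# the index named by OPENSEARCH_INDEX already exists with a matching 4-dimensional
# knn_vector mapping (normally created by the benchmark's configure step), and that
# Record accepts id, vector and metadata keyword arguments. The real harness drives
# this class through the BaseUploader multiprocessing pipeline instead.
if __name__ == "__main__":
    import random

    OpenSearchUploader.init_client(
        host="localhost",
        distance="cosine",  # accepted but unused by init_client
        connection_params={},
        upload_params={},
    )
    demo_batch = [
        Record(id=i, vector=[random.random() for _ in range(4)], metadata=None)
        for i in range(10)
    ]
    OpenSearchUploader.upload_batch(demo_batch)
    OpenSearchUploader.post_upload(None)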