Skip to content

Commit d04294b

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 45e4c9c commit d04294b

File tree

21 files changed

+65
-49
lines changed

21 files changed

+65
-49
lines changed

benchmark/dataset.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,9 @@
11
import os
22
import shutil
3-
from dataclasses import dataclass
4-
from typing import Optional
53
import tarfile
6-
74
import urllib.request
5+
from dataclasses import dataclass
6+
from typing import Optional
87

98
from benchmark import DATASETS_DIR
109
from dataset_reader.ann_h5_reader import AnnH5Reader

dataset_reader/ann_h5_reader.py

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,10 @@
1-
from typing import Iterator, Iterable, Any
1+
from typing import Any, Iterable, Iterator
22

33
import h5py
44
import numpy as np
55

66
from benchmark import DATASETS_DIR
7-
from dataset_reader.base_reader import BaseReader, Record, Query
7+
from dataset_reader.base_reader import BaseReader, Query, Record
88

99

1010
class AnnH5Reader(BaseReader):
@@ -54,5 +54,3 @@ def read_data(self) -> Iterator[Record]:
5454

5555
query = next(AnnH5Reader(test_path).read_queries())
5656
print(query)
57-
58-

dataset_reader/json_reader.py

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,12 +1,10 @@
1+
import json
12
from pathlib import Path
23
from typing import Iterator, List, Optional
34

4-
import json
5-
65
import numpy as np
76

8-
from dataset_reader.base_reader import BaseReader, Record, Query
9-
7+
from dataset_reader.base_reader import BaseReader, Query, Record
108

119
VECTORS_FILE = "vectors.jsonl"
1210
PAYLOADS_FILE = "payloads.jsonl"

engine/base_client/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1+
from engine.base_client.client import BaseClient
12
from engine.base_client.configure import BaseConfigurator
2-
from engine.base_client.upload import BaseUploader
33
from engine.base_client.search import BaseSearcher
4-
from engine.base_client.client import BaseClient
4+
from engine.base_client.upload import BaseUploader

engine/base_client/client.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -2,8 +2,8 @@
22
from datetime import datetime
33
from typing import List
44

5-
from benchmark.dataset import Dataset
65
from benchmark import ROOT_DIR
6+
from benchmark.dataset import Dataset
77
from engine.base_client.configure import BaseConfigurator
88
from engine.base_client.search import BaseSearcher
99
from engine.base_client.upload import BaseUploader
@@ -48,10 +48,11 @@ def save_upload_results(self, dataset_name: str, results: dict):
4848
def run_experiment(self, dataset: Dataset):
4949
print("Experiment stage: Configure")
5050
execution_params = self.configurator.configure(
51-
distance=dataset.config.distance, vector_size=dataset.config.vector_size,
51+
distance=dataset.config.distance,
52+
vector_size=dataset.config.vector_size,
5253
)
5354

54-
reader = dataset.get_reader(execution_params.get('normalize', False))
55+
reader = dataset.get_reader(execution_params.get("normalize", False))
5556
print("Experiment stage: Upload")
5657
upload_stats = self.uploader.upload(reader.read_data())
5758
self.save_upload_results(dataset.config.name, upload_stats)

engine/base_client/search.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,13 +1,12 @@
1-
import time
21
import functools
2+
import time
33
from multiprocessing import get_context
4-
from typing import List, Tuple, Iterable, Optional
4+
from typing import Iterable, List, Optional, Tuple
55

66
import numpy as np
77

88
from dataset_reader.base_reader import Query
99

10-
1110
DEFAULT_TOP = 10
1211

1312

@@ -50,7 +49,8 @@ def _search_one(cls, query, top: Optional[int] = None):
5049
return precision, end - start
5150

5251
def search_all(
53-
self, queries: Iterable[Query],
52+
self,
53+
queries: Iterable[Query],
5454
):
5555
start = time.perf_counter()
5656
parallel = self.search_params.pop("parallel", 1)

engine/base_client/upload.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
import time
22
from multiprocessing import get_context
3-
from typing import List, Optional, Iterable
3+
from typing import Iterable, List, Optional
44

55
import tqdm
66

@@ -21,7 +21,10 @@ def __init__(self, host, connection_params, upload_params):
2121
def init_client(cls, host, connection_params: dict, upload_params: dict):
2222
raise NotImplementedError()
2323

24-
def upload(self, records: Iterable[Record],) -> dict:
24+
def upload(
25+
self,
26+
records: Iterable[Record],
27+
) -> dict:
2528
latencies = []
2629
start = time.perf_counter()
2730
parallel = self.upload_params.pop("parallel", 1)

engine/clients/client_factory.py

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -7,14 +7,13 @@
77
BaseSearcher,
88
BaseUploader,
99
)
10-
from engine.clients.qdrant import QdrantConfigurator, QdrantUploader, QdrantSearcher
10+
from engine.clients.milvus import MilvusConfigurator, MilvusSearcher, MilvusUploader
11+
from engine.clients.qdrant import QdrantConfigurator, QdrantSearcher, QdrantUploader
1112
from engine.clients.weaviate import (
1213
WeaviateConfigurator,
13-
WeaviateUploader,
1414
WeaviateSearcher,
15+
WeaviateUploader,
1516
)
16-
from engine.clients.milvus import MilvusConfigurator, MilvusUploader, MilvusSearcher
17-
1817

1918
ENGINE_CONFIGURATORS = {
2019
"qdrant": QdrantConfigurator,

engine/clients/milvus/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,3 @@
11
from engine.clients.milvus.configure import MilvusConfigurator
2-
from engine.clients.milvus.upload import MilvusUploader
32
from engine.clients.milvus.search import MilvusSearcher
3+
from engine.clients.milvus.upload import MilvusUploader

engine/clients/milvus/configure.py

Lines changed: 17 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1,16 +1,15 @@
11
import pymilvus.client.exceptions as milvus_exceptions
2+
from pymilvus import Collection, CollectionSchema, DataType, FieldSchema, connections
23
from pymilvus.orm import utility
34

45
from engine.base_client.configure import BaseConfigurator
56
from engine.base_client.distances import Distance
67
from engine.clients.milvus.config import (
78
MILVUS_COLLECTION_NAME,
8-
MILVUS_DEFAULT_PORT,
99
MILVUS_DEFAULT_ALIAS,
10+
MILVUS_DEFAULT_PORT,
1011
)
1112

12-
from pymilvus import connections, DataType, FieldSchema, CollectionSchema, Collection
13-
1413

1514
class MilvusConfigurator(BaseConfigurator):
1615
DISTANCE_MAPPING = {
@@ -39,18 +38,29 @@ def clean(self):
3938
pass
4039

4140
def recreate(
42-
self, distance, vector_size, collection_params,
41+
self,
42+
distance,
43+
vector_size,
44+
collection_params,
4345
):
44-
idx = FieldSchema(name="id", dtype=DataType.INT64, is_primary=True,)
46+
idx = FieldSchema(
47+
name="id",
48+
dtype=DataType.INT64,
49+
is_primary=True,
50+
)
4551
vector = FieldSchema(
46-
name="vector", dtype=DataType.FLOAT_VECTOR, dim=vector_size,
52+
name="vector",
53+
dtype=DataType.FLOAT_VECTOR,
54+
dim=vector_size,
4755
)
4856
schema = CollectionSchema(
4957
fields=[idx, vector], description=MILVUS_COLLECTION_NAME
5058
)
5159

5260
collection = Collection(
53-
name=MILVUS_COLLECTION_NAME, schema=schema, using=MILVUS_DEFAULT_ALIAS,
61+
name=MILVUS_COLLECTION_NAME,
62+
schema=schema,
63+
using=MILVUS_DEFAULT_ALIAS,
5464
)
5565

5666
for index in collection.indexes:

0 commit comments

Comments
 (0)