
Commit 3bf6124
format code with black
yangbodong22011 committed Aug 1, 2023
1 parent 7f15861 commit 3bf6124
Showing 4 changed files with 155 additions and 118 deletions.
19 changes: 14 additions & 5 deletions tair/tairsearch.py
@@ -153,11 +153,18 @@ def tft_search(self, index: KeyT, query: str, use_cache: bool = False) -> Respon
pieces.append("use_cache")
return self.execute_command("TFT.SEARCH", *pieces)

def tft_msearch(self, index_count: int, index: Iterable[KeyT], query: str) -> ResponseT:
def tft_msearch(
self, index_count: int, index: Iterable[KeyT], query: str
) -> ResponseT:
return self.execute_command("TFT.MSEARCH", index_count, *index, query)

def tft_analyzer(self, analyzer_name: str, text: str, index: Optional[KeyT] = None,
show_time: Optional[bool] = False) -> ResponseT:
def tft_analyzer(
self,
analyzer_name: str,
text: str,
index: Optional[KeyT] = None,
show_time: Optional[bool] = False,
) -> ResponseT:
pieces: List[EncodableT] = [analyzer_name, text]
if index is not None:
pieces.append("INDEX")
@@ -167,9 +174,11 @@ def tft_analyzer(self, analyzer_name: str, text: str, index: Optional[KeyT] = No
target_nodes = None
if isinstance(self, tair.TairCluster):
if index is None:
target_nodes = 'random'
target_nodes = "random"
else:
target_nodes = self.nodes_manager.get_node_from_slot(self.keyslot(index))
target_nodes = self.nodes_manager.get_node_from_slot(
self.keyslot(index)
)
return self.execute_command("TFT.ANALYZER", *pieces, target_nodes=target_nodes)

def tft_explaincost(self, index: KeyT, query: str) -> ResponseT:
178 changes: 89 additions & 89 deletions tair/tairvector.py
@@ -123,11 +123,11 @@ def __init__(self, client, name, **index_params):

# bind methods
for method in (
"tvs_del",
"tvs_hdel",
"tvs_hgetall",
"tvs_hmget",
"tvs_scan",
"tvs_del",
"tvs_hdel",
"tvs_hgetall",
"tvs_hmget",
"tvs_scan",
):
attr = getattr(TairVectorCommands, method)
if callable(attr):
@@ -150,23 +150,23 @@ def tvs_hset(self, key: str, vector: Union[VectorType, str, None] = None, **kwar
return self.client.tvs_hset(self.name, key, vector, self.is_binary, **kwargs)

def tvs_knnsearch(
self,
k: int,
vector: Union[VectorType, str],
filter_str: Optional[str] = None,
**kwargs
self,
k: int,
vector: Union[VectorType, str],
filter_str: Optional[str] = None,
**kwargs
):
"""search for the top @k approximate nearest neighbors of @vector"""
return self.client.tvs_knnsearch(
self.name, k, vector, self.is_binary, filter_str, **kwargs
)

def tvs_mknnsearch(
self,
k: int,
vectors: Sequence[VectorType],
filter_str: Optional[str] = None,
**kwargs
self,
k: int,
vectors: Sequence[VectorType],
filter_str: Optional[str] = None,
**kwargs
):
"""batch approximate nearest neighbors search for a list of vectors"""
return self.client.tvs_mknnsearch(
@@ -190,13 +190,13 @@ class TairVectorCommands(CommandsProtocol):
SCAN_INDEX_CMD = "TVS.SCANINDEX"

def tvs_create_index(
self,
name: str,
dim: int,
distance_type: str = DistanceMetric.L2,
index_type: str = IndexType.HNSW,
data_type: str = DataType.Float32,
**kwargs
self,
name: str,
dim: int,
distance_type: str = DistanceMetric.L2,
index_type: str = IndexType.HNSW,
data_type: str = DataType.Float32,
**kwargs
):
"""
create a vector
@@ -231,7 +231,7 @@ def tvs_del_index(self, name: str):
return self.execute_command(self.DEL_INDEX_CMD, name)

def tvs_scan_index(
self, pattern: Optional[str] = None, batch: int = 10
self, pattern: Optional[str] = None, batch: int = 10
) -> TairVectorScanResult:
"""
scan all the indices
@@ -257,12 +257,12 @@ def tvs_index(self, name: str, **index_params) -> TairVectorIndex:
SCAN_CMD = "TVS.SCAN"

def tvs_hset(
self,
index: str,
key: str,
vector: Union[VectorType, str, None] = None,
is_binary=False,
**kwargs
self,
index: str,
key: str,
vector: Union[VectorType, str, None] = None,
is_binary=False,
**kwargs
):
"""
add/update a data entry to index
@@ -309,13 +309,13 @@ def tvs_hmget(self, index: str, key: str, *args):
return self.execute_command(self.HMGET_CMD, index, key, *args)

def tvs_scan(
self,
index: str,
pattern: Optional[str] = None,
batch: int = 10,
filter_str: Optional[str] = None,
vector: Optional[VectorType] = None,
max_dist: Optional[float] = None,
self,
index: str,
pattern: Optional[str] = None,
batch: int = 10,
filter_str: Optional[str] = None,
vector: Optional[VectorType] = None,
max_dist: Optional[float] = None,
):
"""
scan all data entries in an index
@@ -340,14 +340,14 @@ def get_batch(c):
return TairVectorScanResult(self, get_batch)

def _tvs_scan(
self,
index: str,
cursor: int = 0,
count: Optional[int] = None,
pattern: Optional[str] = None,
filter_str: Optional[str] = None,
vector: Union[VectorType, bytes, None] = None,
max_dist: Optional[float] = None,
self,
index: str,
cursor: int = 0,
count: Optional[int] = None,
pattern: Optional[str] = None,
filter_str: Optional[str] = None,
vector: Union[VectorType, bytes, None] = None,
max_dist: Optional[float] = None,
):
args = [] if pattern is None else ["MATCH", pattern]
if count is not None:
@@ -374,13 +374,13 @@ def _tvs_scan(
MINDEXMKNNSEARCH_CMD = "TVS.MINDEXMKNNSEARCH"

def tvs_knnsearch(
self,
index: str,
k: int,
vector: Union[VectorType, str, bytes],
is_binary: bool = False,
filter_str: Optional[str] = None,
**kwargs
self,
index: str,
k: int,
vector: Union[VectorType, str, bytes],
is_binary: bool = False,
filter_str: Optional[str] = None,
**kwargs
):
"""
search for the top @k approximate nearest neighbors of @vector in an index
@@ -395,13 +395,13 @@ def tvs_knnsearch(
)

def tvs_mknnsearch(
self,
index: str,
k: int,
vectors: Sequence[VectorType],
is_binary: bool = False,
filter_str: Optional[str] = None,
**kwargs
self,
index: str,
k: int,
vectors: Sequence[VectorType],
is_binary: bool = False,
filter_str: Optional[str] = None,
**kwargs
):
"""
batch approximate nearest neighbors search for a list of vectors
@@ -430,13 +430,13 @@ def tvs_mknnsearch(
)

def tvs_mindexknnsearch(
self,
index: Sequence[str],
k: int,
vector: Union[VectorType, str, bytes],
is_binary: bool = False,
filter_str: Optional[str] = None,
**kwargs
self,
index: Sequence[str],
k: int,
vector: Union[VectorType, str, bytes],
is_binary: bool = False,
filter_str: Optional[str] = None,
**kwargs
):
"""
search for the top @k approximate nearest neighbors of @vector in indexs
@@ -453,13 +453,13 @@ def tvs_mindexknnsearch(
)

def tvs_mindexmknnsearch(
self,
index: Sequence[str],
k: int,
vectors: Sequence[VectorType],
is_binary: bool = False,
filter_str: Optional[str] = None,
**kwargs
self,
index: Sequence[str],
k: int,
vectors: Sequence[VectorType],
is_binary: bool = False,
filter_str: Optional[str] = None,
**kwargs
):
"""
batch approximate nearest neighbors search for a list of vectors
@@ -492,13 +492,13 @@ def tvs_mindexmknnsearch(
GETDISTANCE_CMD = "TVS.GETDISTANCE"

def _tvs_getdistance(
self,
index_name: str,
vector: VectorType,
keys: Iterable[str],
top_n: Optional[int] = None,
max_dist: Optional[float] = None,
filter_str: Optional[str] = None,
self,
index_name: str,
vector: VectorType,
keys: Iterable[str],
top_n: Optional[int] = None,
max_dist: Optional[float] = None,
filter_str: Optional[str] = None,
):
"""
low level interface for TVS.GETDISTANCE
@@ -520,15 +520,15 @@ def _tvs_getdistance(
)

def tvs_getdistance(
self,
index_name: str,
vector: Union[VectorType, str, bytes],
keys: Iterable[str],
batch_size: int = 100000,
parallelism: int = 1,
top_n: Optional[int] = None,
max_dist: Optional[float] = None,
filter_str: Optional[str] = None,
self,
index_name: str,
vector: Union[VectorType, str, bytes],
keys: Iterable[str],
batch_size: int = 100000,
parallelism: int = 1,
top_n: Optional[int] = None,
max_dist: Optional[float] = None,
filter_str: Optional[str] = None,
):
"""
wrapped interface for TVS.GETDISTANCE
@@ -562,7 +562,7 @@ def process_batch(batch):

with ThreadPoolExecutor(max_workers=parallelism) as executor:
batches = [
keys[i: i + batch_size] for i in range(0, len(keys), batch_size)
keys[i : i + batch_size] for i in range(0, len(keys), batch_size)
]

futures = [executor.submit(process_batch, batch) for batch in batches]
23 changes: 11 additions & 12 deletions tests/test_tairsearch.py
@@ -50,7 +50,6 @@ def test_tft_updateindex(self, t: Tair):
assert t.tft_updateindex(index, mappings2)
t.delete(index)


def test_tft_getindex(self, t: Tair):
index = "idx_" + str(uuid.uuid4())
mappings = """
@@ -455,7 +454,7 @@ def test_tft_search(self, t: Tair):
result = t.tft_search(index, '{"sort":[{"price":{"order":"desc"}}]}', True)
assert json.loads(want) == json.loads(result)
result = t.tft_explaincost(index, '{"sort":[{"price":{"order":"desc"}}]}')
assert json.loads(result)['QUERY_COST']
assert json.loads(result)["QUERY_COST"]
t.delete(index)

def test_tft_msearch(self, t: Tair):
@@ -480,12 +479,8 @@ def test_tft_msearch(self, t: Tair):

assert t.tft_createindex(index1, mappings)
assert t.tft_createindex(index2, mappings)
assert t.tft_madddoc(
index1, {document1: "00001", document2: "00002"}
)
assert t.tft_madddoc(
index2, {document3: "00003", document4: "00004"}
)
assert t.tft_madddoc(index1, {document1: "00001", document2: "00002"})
assert t.tft_madddoc(index2, {document3: "00003", document4: "00004"})

want = f"""{{
"aux_info": {{"index_crc64": 5843875291690071373}},
@@ -520,7 +515,9 @@ def test_tft_msearch(self, t: Tair):
"total": {{ "relation": "eq", "value": 4 }}
}}
}}"""
result = t.tft_msearch(2, {index1, index2}, '{"sort":[{"_doc":{"order":"asc"}}]}')
result = t.tft_msearch(
2, {index1, index2}, '{"sort":[{"_doc":{"order":"asc"}}]}'
)
assert json.loads(want) == json.loads(result)
t.delete(index1)
t.delete(index2)
@@ -547,11 +544,13 @@ def test_tft_analyzer(self, t: Tair):
}
}
}"""
text = 'This is tair-py.'
text = "This is tair-py."

assert t.tft_createindex(index, mappings)
assert t.tft_analyzer("standard", text) == t.tft_analyzer("my_analyzer", text, index)
assert 'consuming time' in str(t.tft_analyzer("standard", text, None, True))
assert t.tft_analyzer("standard", text) == t.tft_analyzer(
"my_analyzer", text, index
)
assert "consuming time" in str(t.tft_analyzer("standard", text, None, True))

t.delete(index)
