Skip to content

Commit

Permalink
test query document ids and check ivfpq index
Browse files Browse the repository at this point in the history
  • Loading branch information
syhao committed Mar 27, 2024
1 parent b80e9df commit a1b41aa
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 40 deletions.
2 changes: 0 additions & 2 deletions sdk/python/vearch/core/client.py
Expand Up @@ -5,8 +5,6 @@
from vearch.const import DATABASE_URI, LIST_DATABASE_URI, AUTH_KEY
from vearch.result import Result, get_result
import logging
import base64
import json

logger = logging.getLogger("vearch")

Expand Down
36 changes: 18 additions & 18 deletions sdk/python/vearch/core/space.py
Expand Up @@ -3,7 +3,7 @@
from vearch.schema.space import SpaceSchema
from vearch.result import Result, ResultStatus, get_result, UpsertResult
from vearch.const import SPACE_URI, INDEX_URI, UPSERT_DOC_URI, DELETE_DOC_URI, QUERY_DOC_URI, SEARCH_DOC_URI, \
ERR_CODE_SPACE_NOT_EXIST
ERR_CODE_SPACE_NOT_EXIST, AUTH_KEY
from vearch.exception import SpaceException, DocumentException, VearchException
from vearch.utils import CodeType, VectorInfo, compute_sign_auth, DataType
from vearch.filter import Filter
Expand All @@ -29,15 +29,15 @@ def create(self, space: SpaceSchema) -> Result:
if not self._schema:
self._schema = space
sign = compute_sign_auth(secret=self.client.token)
req = requests.request(method="POST", url=url, data=space.dict(), headers={"Authorization": sign})
req = requests.request(method="POST", url=url, data=space.dict(), headers={AUTH_KEY: sign})
resp = self.client.s.send(req)
return get_result(resp)

def drop(self) -> Result:
    """Delete this space from its database on the server.

    The space URL is built from ``self.db_name`` and ``self.name``; the
    call is signed with ``compute_sign_auth`` using the client's token and
    sent as an HTTP DELETE.

    Returns:
        The ``Result`` that ``get_result`` builds from the HTTP response.
    """
    url_params = {"database_name": self.db_name, "space_name": self.name}
    url = self.client.host + SPACE_URI % url_params
    sign = compute_sign_auth(secret=self.client.token)
    # requests.request() performs the HTTP call itself and returns a
    # Response, so the response is handed straight to get_result() (as
    # exist(), create_index() and query() do) rather than being passed
    # to a session's send(), which expects a prepared request.
    resp = requests.request(method="DELETE", url=url, headers={AUTH_KEY: sign})
    return get_result(resp)

Expand All @@ -47,7 +47,7 @@ def exist(self) -> [bool, SpaceSchema]:
uri = SPACE_URI % url_params
url = self.client.host + str(uri)
sign = compute_sign_auth(secret=self.client.token)
resp = requests.request(method="GET", url=url, headers={"Authorization": sign})
resp = requests.request(method="GET", url=url, headers={AUTH_KEY: sign})
result = get_result(resp)
if result.code == 200:
space_schema_dict = result.text
Expand All @@ -66,7 +66,7 @@ def create_index(self, field: str, index: Index) -> Result:
req_body = {"field": field, "index": index.dict(), "database": self.db_name, "space": self.name}
sign = compute_sign_auth(secret=self.client.token)
resp = requests.request(method="POST", url=url, data=json.dumps(req_body),
headers={"Authorization": sign})
headers={AUTH_KEY: sign})
return get_result(resp)

def upsert_doc(self, data: Union[List, pd.DataFrame]) -> UpsertResult:
Expand All @@ -92,16 +92,13 @@ def upsert_doc(self, data: Union[List, pd.DataFrame]) -> UpsertResult:
for em in data:
record = {}
for i, field in enumerate(self._schema.fields):
if field.data_type == DataType.VECTOR:
record[field.name] = {"feature": em[i]}
else:
record[field.name] = em[i]
record[field.name] = em[i]
records.append(record)
req_body.update({"documents": records})
logger.debug(req_body)
sign = compute_sign_auth(secret=self.client.token)
resp = requests.request(method="POST", url=url, data=json.dumps(req_body),
headers={"Authorization": sign})
headers={AUTH_KEY: sign})
return UpsertResult.parse_upsert_result_from_response(resp)
else:
raise DocumentException(CodeType.UPSERT_DOC, "data fields not conform space schema")
Expand All @@ -125,7 +122,7 @@ def delete_doc(self, filter: Filter) -> Result:
url = self.client.host + DELETE_DOC_URI
req_body = {"database": self.db_name, "space": self.name, "filter": filter.dict()}
req = requests.request(method="POST", url=url, data=json.dumps(req_body),
headers={"Authorization": self.client.token})
headers={AUTH_KEY: self.client.token})
resp = self.client.s.send(req)
return get_result(resp)

Expand Down Expand Up @@ -195,15 +192,16 @@ def search(self, document_ids: Optional[List], vector_infos: Optional[List[Vecto
if kwargs:
req_body.update(kwargs)
req = requests.request(method="POST", url=url, data=json.dumps(req_body),
headers={"Authorization": self.client.token})
headers={AUTH_KEY: self.client.token})
resp = self.client.s.send(req)
ret = get_result(resp)
if ret.code != ResultStatus.success:
raise SpaceException(CodeType.SEARCH_DOC, ret.err_msg)
search_ret = json.dumps(ret.code)
return search_ret

def query(self, document_ids: Optional[List], filter: Optional[Filter], partition_id: Optional[str] = "",
def query(self, document_ids: Optional[List] = [], filter: Optional[Filter] = None,
partition_id: Optional[str] = "",
fields: Optional[List] = [], vector: bool = False, size: int = 50) -> List[Dict]:
"""
You can assign the document_ids as [xxx, xxx, xxx, xxx, xxx], or give another filter condition.
Expand All @@ -219,7 +217,7 @@ def query(self, document_ids: Optional[List], filter: Optional[Filter], partitio
if (not document_ids) and (not filter):
raise SpaceException(CodeType.QUERY_DOC, "document_ids and filter can not both null")
url = self.client.host + QUERY_DOC_URI
req_body = {"database": self.db_name, "space": self.name, "vector_value": vector, "size": size}
req_body = {"db_name": self.db_name, "space_name": self.name, "vector_value": vector}
query = {"query": {}}
if document_ids:
query["query"]["document_ids"] = document_ids
Expand All @@ -230,9 +228,11 @@ def query(self, document_ids: Optional[List], filter: Optional[Filter], partitio
if filter:
query["query"]["filter"] = filter.dict()
req_body.update(query)
req = requests.request(method="POST", url=url, data=json.dumps(req_body), headers={"token": self.client.token})
resp = self.client.s.send(req)
logger.debug(url)
logger.debug(json.dumps(req_body))
sign = compute_sign_auth(secret=self.client.token)
resp = requests.request(method="POST", url=url, data=json.dumps(req_body), headers={AUTH_KEY: sign})
ret = get_result(resp)
if ret.code == ResultStatus.success:
return json.dumps(ret.content)
if ret.code == 200:
return json.dumps(ret.text)
return []
19 changes: 14 additions & 5 deletions sdk/python/vearch/example.py
Expand Up @@ -6,14 +6,15 @@
from vearch.utils import DataType, MetricType
from vearch.schema.index import IvfPQIndex, Index, ScalarIndex
import logging
from typing import List

logger = logging.getLogger("vearch")


def create_space_schema() -> SpaceSchema:
    """Build the example ``book_info`` space schema.

    The schema has three fields: a scalar-indexed string ``book_name``, a
    512-dimensional vector ``book_character`` with an IVFPQ index, and a
    plain string ``ractor_address``.

    Returns:
        The assembled ``SpaceSchema``.
    """
    book_name = Field("book_name", DataType.STRING, desc="the name of book", index=ScalarIndex("book_name_idx"))
    # The last IvfPQIndex argument (presumably nsubvector -- confirm against
    # the IvfPQIndex signature) is 8: the schema validation asserts that the
    # vector dimension is divisible by nsubvector, and 512 % 40 != 0.
    book_vector = Field("book_character", DataType.VECTOR,
                        IvfPQIndex("book_vec_idx", 10000, MetricType.Inner_product, 2048, 8), dimension=512)
    ractor_address = Field("ractor_address", DataType.STRING, desc="the place of the book put")
    space_schema = SpaceSchema("book_info", fields=[book_name, book_vector, ractor_address])
    return space_schema
Expand All @@ -37,7 +38,7 @@ def create_space(vc: Vearch):
print(ret.text, ret.err_msg)


def upsert_document(vc: Vearch):
def upsert_document(vc: Vearch) -> List:
import random
ractor = ["ractor_logical", "ractor_industry", "ractor_philosophy"]
book_name_template = "abcdefghijklmnopqrstuvwxyz0123456789"
Expand All @@ -51,8 +52,15 @@ def upsert_document(vc: Vearch):
space = Space("database1", "book_info")
ret = space.upsert_doc(data)
if ret:
print(ret.get_document_ids())
pass
return ret.get_document_ids()
return []


def query_documents(ids: List):
    """Query the example space for the given document ids and log each hit.

    Args:
        ids: document ids to look up, e.g. the list returned by
            ``upsert_document``. May be empty when the upsert failed.

    Returns:
        None. Each item yielded by the query result is written to the
        ``vearch`` logger at DEBUG level.
    """
    # Space.query raises when neither ids nor a filter is supplied, and
    # upsert_document returns [] on failure, so there is nothing to query.
    if not ids:
        return
    space = Space("database1", "book_info")
    ret = space.query(ids)
    # NOTE(review): Space.query appears to return json.dumps(...) (a str) on
    # success; iterating a str yields characters -- confirm the return shape.
    for doc in ret:
        logger.debug(doc)


def is_database_exist(vc: Vearch):
Expand Down Expand Up @@ -93,7 +101,8 @@ def drop_database(vc: Vearch):
print(space_exist)
if not space_exist:
create_space(vc)
upsert_document(vc)
ids = upsert_document(vc)
query_documents(ids)

delete_space(vc)
drop_database(vc)
Expand Down
15 changes: 9 additions & 6 deletions sdk/python/vearch/schema/index.py
Expand Up @@ -52,7 +52,7 @@ def __init__(self, index_name: str, training_threshold: int, metric_type: str, n

def dict(self):
return {"name": self._index_name, "type": IndexType.IVFPQ,
"index_params": {
"params": {
"training_threshold": self._index_params.training_threshold,
"metric_type": self._index_params.metric_type, "ncentroids": self._index_params.ncentroids,
"nsubvector": self._index_params.nsubvector,
Expand All @@ -61,6 +61,9 @@ def dict(self):
}
}

def nsubvector(self):
    """Return the ``nsubvector`` value held in this index's parameters."""
    index_params = self._index_params
    return index_params.nsubvector


class IvfFlatIndex(Index):
def __init__(self, index_name: str, metric_type: str, ncentroids: int, **kwargs):
Expand All @@ -69,7 +72,7 @@ def __init__(self, index_name: str, metric_type: str, ncentroids: int, **kwargs)

def dict(self):
return {"name": self._index_name, "type": IndexType.IVFFLAT,
"index_params": {
"params": {
"metric_type": self._index_params.metric_type,
"ncentroids": self._index_params.ncentroids
}
Expand All @@ -87,7 +90,7 @@ def __init__(self, index_name: str, ncentroids: int, **kwargs):

def dict(self):
    """Serialize this binary IVF index definition to a plain dictionary.

    Returns:
        A dict with the index ``name``, its ``type``
        (``IndexType.BINARYIVF``) and a ``params`` mapping holding
        ``ncentroids``.
    """
    # The parameter mapping is keyed "params", matching the dict() methods
    # of the other index classes in this module.
    return {
        "name": self._index_name,
        "type": IndexType.BINARYIVF,
        "params": {
            "ncentroids": self._index_params.ncentroids,
        },
    }

Expand All @@ -99,7 +102,7 @@ def __init__(self, index_name: str, metric_type: str, **kwargs):

def dict(self):
    """Serialize this flat index definition to a plain dictionary.

    Returns:
        A dict with the index ``name``, its ``type`` (``IndexType.FLAT``)
        and a ``params`` mapping holding ``metric_type``.
    """
    # The parameter mapping is keyed "params", matching the dict() methods
    # of the other index classes in this module.
    return {
        "name": self._index_name,
        "type": IndexType.FLAT,
        "params": {
            "metric_type": self._index_params.metric_type,
        },
    }
Expand All @@ -111,7 +114,7 @@ def __init__(self, index_name: str, metric_type: str, nlinks: int, efConstructio
self._index_params = IndexParams(metric_type=metric_type, nlinks=nlinks, efConstruction=efConstruction)

def dict(self):
return {"name": self._index_name, "type": IndexType.HNSW, "index_params": {
return {"name": self._index_name, "type": IndexType.HNSW, "params": {
"nlinks": self._index_params.nlinks,
"efConstruction": self._index_params.efConstruction, "metric_type": self._index_params.metric_type
}
Expand All @@ -125,7 +128,7 @@ def __int__(self, index_name: str, metric_type: str, ncentroids: int, nsubvector

def dict(self):
return {"name": self._index_name, "type": IndexType.GPU_IVFPQ,
"index_params": {
"params": {
"metric_type": self._index_params.metric_type,
"ncentroids": self._index_params.ncentroids, "nsubvector": self._index_params.nsubvector
}
Expand Down
18 changes: 10 additions & 8 deletions sdk/python/vearch/schema/space.py
@@ -1,14 +1,15 @@
from typing import List, Optional
from vearch.utils import DataType
from vearch.schema.index import BinaryIvfIndex
from vearch.schema.index import BinaryIvfIndex, IvfPQIndex
from vearch.schema.field import Field
import logging
import json

logger = logging.getLogger("vearch")


class SpaceSchema:
def __init__(self,name:str, fields: List, description: str = "",
def __init__(self, name: str, fields: List, description: str = "",
partition_num: int = 1,
replication_num: int = 3):
"""
Expand All @@ -31,8 +32,9 @@ def _check_valid(self):
if field.index:
assert field.data_type not in [DataType.NONE, DataType.UNKNOWN]
if isinstance(field.index, BinaryIvfIndex):
assert field.dim // 8 == 0, "BinaryIvfIndex vector dimention must be power of eight"

assert field.dim % 8 == 0, "BinaryIvfIndex vector dimention must be power of eight"
if isinstance(field.index, IvfPQIndex):
assert field.dim % field.index.nsubvector() == 0, "IVFPQIndex vector dimention must be power of nsubvector"

def dict(self):
space_schema = {"name": self.name, "desc": self.description, "partition_num": self.partition_num,
Expand All @@ -47,12 +49,12 @@ def dict(self):
@classmethod
def from_dict(cls, data_dict):
    """Build a schema instance from a dictionary description of a space.

    Args:
        data_dict: mapping read with the keys ``space_name``, ``schema``
            (whose ``fields`` list is converted item by item with
            ``Field.from_dict``), ``desc``, ``partition_num`` and
            ``replica_num``.

    Returns:
        A new instance of ``cls`` populated from ``data_dict``.
    """
    # Diagnostics go to the module logger instead of stdout so that library
    # users control the output.
    logger.debug(data_dict)
    name = data_dict.get("space_name")
    schema_dict = data_dict.get("schema")
    logger.debug(schema_dict)
    fields = [Field.from_dict(field) for field in schema_dict.get("fields")]
    logger.debug(type(name))
    return cls(name=name, fields=fields,
               description=data_dict.get("desc", ""),
               partition_num=data_dict.get("partition_num"),
               replication_num=data_dict.get("replica_num"))
2 changes: 1 addition & 1 deletion sdk/python/vearch/utils.py
Expand Up @@ -16,7 +16,7 @@
"format": "%(asctime)s - %(levelname)s - %(filename)s[:%(lineno)d] - %(message)s",
},
"normal": {
"format": "%(asctime)s - %(levelname)s - %(message)s",
"format": "%(asctime)s - %(levelname)s - %(filename)s[:%(lineno)d] - %(message)s",
}

},
Expand Down

0 comments on commit a1b41aa

Please sign in to comment.