Skip to content

Commit

Permalink
feat: Vector search API wrapper (#39)
Browse files Browse the repository at this point in the history
* feat: Vector stores API wrapper

* test: patch tigris client in unit tests

* docs: docstrings
  • Loading branch information
adilansari committed Jun 1, 2023
1 parent 96d24c9 commit eef942f
Show file tree
Hide file tree
Showing 11 changed files with 311 additions and 27 deletions.
6 changes: 6 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@ repos:
pep8-naming==0.13.3,
flake8-bugbear==23.5.9
]
- repo: https://github.com/PyCQA/docformatter
rev: v1.7.1
hooks:
- id: docformatter
additional_dependencies: [ tomli ]
args: [ --in-place, --black ]
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
hooks:
Expand Down
12 changes: 11 additions & 1 deletion tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


class StubRpcError(grpc.RpcError):
def __init__(self, code: str, details: Optional[str]):
def __init__(self, code: grpc.StatusCode, details: Optional[str]):
self._code = code
self._details = details

Expand All @@ -13,3 +13,13 @@ def code(self):

def details(self):
return self._details


class UnavailableRpcError(StubRpcError):
    """Stub RPC error whose status code is fixed to ``UNAVAILABLE``."""

    def __init__(self, details: Optional[str]):
        # Only the details text varies; the status code is pinned here.
        super().__init__(code=grpc.StatusCode.UNAVAILABLE, details=details)


class NotFoundRpcError(StubRpcError):
    """Stub RPC error whose status code is fixed to ``NOT_FOUND``."""

    def __init__(self, details: Optional[str]):
        # Only the details text varies; the status code is pinned here.
        super().__init__(code=grpc.StatusCode.NOT_FOUND, details=details)
2 changes: 1 addition & 1 deletion tests/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def test_get_access_token_with_rpc_failure(self, channel_ready_future, grpc_auth
channel_ready_future.return_value = self.done_future
mock_grpc_auth = grpc_auth()
mock_grpc_auth.GetAccessToken.side_effect = StubRpcError(
code="Unavailable", details=""
code=grpc.StatusCode.UNAVAILABLE, details=""
)

auth_gateway = AuthGateway(self.client_config)
Expand Down
26 changes: 8 additions & 18 deletions tests/test_search_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
SearchIndexResponse,
UpdateDocumentResponse,
)
from tests import StubRpcError
from tests import UnavailableRpcError
from tigrisdb.errors import TigrisServerError
from tigrisdb.search_index import SearchIndex
from tigrisdb.types import ClientConfig, Document
Expand Down Expand Up @@ -54,9 +54,7 @@ def test_search(self, grpc_search):
def test_search_with_error(self, grpc_search):
search_index = SearchIndex(self.index_name, grpc_search(), self.client_config)
mock_grpc = grpc_search()
mock_grpc.Search.side_effect = StubRpcError(
code="Unavailable", details="operational failure"
)
mock_grpc.Search.side_effect = UnavailableRpcError("operational failure")
with self.assertRaisesRegex(TigrisServerError, "operational failure") as e:
search_index.search(SearchQuery())
self.assertIsNotNone(e)
Expand Down Expand Up @@ -88,9 +86,7 @@ def test_create_many_with_error(self, grpc_search):
docs = [{"id": 1, "name": "shoe"}, {"id": 2, "name": "jacket"}]
search_index = SearchIndex(self.index_name, grpc_search(), self.client_config)
mock_grpc = grpc_search()
mock_grpc.Create.side_effect = StubRpcError(
code="Unavailable", details="operational failure"
)
mock_grpc.Create.side_effect = UnavailableRpcError("operational failure")

with self.assertRaisesRegex(TigrisServerError, "operational failure") as e:
search_index.create_many(docs)
Expand Down Expand Up @@ -132,9 +128,7 @@ def test_delete_many(self, grpc_search):
def test_delete_many_with_error(self, grpc_search):
search_index = SearchIndex(self.index_name, grpc_search(), self.client_config)
mock_grpc = grpc_search()
mock_grpc.Delete.side_effect = StubRpcError(
code="Unavailable", details="operational failure"
)
mock_grpc.Delete.side_effect = UnavailableRpcError("operational failure")

with self.assertRaisesRegex(TigrisServerError, "operational failure") as e:
search_index.delete_many(["id"])
Expand Down Expand Up @@ -178,8 +172,8 @@ def test_create_or_replace_many_with_error(self, grpc_search):
docs = [{"id": 1, "name": "shoe"}, {"id": 2, "name": "jacket"}]
search_index = SearchIndex(self.index_name, grpc_search(), self.client_config)
mock_grpc = grpc_search()
mock_grpc.CreateOrReplace.side_effect = StubRpcError(
code="Unavailable", details="operational failure"
mock_grpc.CreateOrReplace.side_effect = UnavailableRpcError(
"operational failure"
)

with self.assertRaisesRegex(TigrisServerError, "operational failure") as e:
Expand Down Expand Up @@ -242,9 +236,7 @@ def test_get_many(self, grpc_search):
def test_get_many_with_error(self, grpc_search):
search_index = SearchIndex(self.index_name, grpc_search(), self.client_config)
mock_grpc = grpc_search()
mock_grpc.Get.side_effect = StubRpcError(
code="Unavailable", details="operational failure"
)
mock_grpc.Get.side_effect = UnavailableRpcError("operational failure")

with self.assertRaisesRegex(TigrisServerError, "operational failure") as e:
search_index.get_many(["id"])
Expand Down Expand Up @@ -288,9 +280,7 @@ def test_update_many_with_error(self, grpc_search):
docs = [{"id": 1, "name": "shoe"}, {"id": 2, "name": "jacket"}]
search_index = SearchIndex(self.index_name, grpc_search(), self.client_config)
mock_grpc = grpc_search()
mock_grpc.Update.side_effect = StubRpcError(
code="Unavailable", details="operational failure"
)
mock_grpc.Update.side_effect = UnavailableRpcError("operational failure")

with self.assertRaisesRegex(TigrisServerError, "operational failure") as e:
search_index.update_many(docs)
Expand Down
111 changes: 111 additions & 0 deletions tests/test_vector_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
from unittest import TestCase
from unittest.mock import MagicMock, Mock, call, patch

from tests import NotFoundRpcError
from tigrisdb.errors import TigrisServerError
from tigrisdb.search import Search
from tigrisdb.search_index import SearchIndex
from tigrisdb.types.search import (
DocMeta,
DocStatus,
IndexedDoc,
Query,
Result,
TextMatchInfo,
VectorField,
)
from tigrisdb.types.vector import Document
from tigrisdb.vector_store import VectorStore

# Shared document fixture used by the tests below; it carries three
# embedding values and no "id" key.
doc: Document = Document(
    text="Hello world vector embed",
    embeddings=[1.2, 2.3, 4.5],
    metadata={"category": "shoes"},
)


class VectorStoreTest(TestCase):
    """Unit tests for ``VectorStore`` running against mocked search objects."""

    def setUp(self) -> None:
        """Create a ``VectorStore`` wired to a mock search client and index."""
        self.mock_index = Mock(spec=SearchIndex)
        self.mock_client = Mock(spec=Search)
        with patch("tigrisdb.client.TigrisClient.__new__") as mock_tigris:
            instance = MagicMock()
            mock_tigris.return_value = instance
            instance.get_search.return_value = self.mock_client
            self.mock_client.get_index.return_value = self.mock_index
            # Build the store while the patch is active so that no real
            # TigrisClient is constructed (presumably VectorStore creates
            # one internally - confirm against tigrisdb.vector_store).
            self.store = VectorStore("my_vectors")

    def test_add_documents_when_index_not_found(self):
        """A missing index is created once and the insert is retried."""
        # throw error on first call and succeed on second
        self.mock_index.create_many.side_effect = [
            TigrisServerError("", NotFoundRpcError("search index not found")),
            [DocStatus(id="1")],
        ]

        resp = self.store.add_documents([doc])
        self.assertEqual([DocStatus(id="1")], resp)
        self.assertEqual(self.mock_index.create_many.call_count, 2)
        self.mock_index.create_many.assert_has_calls([call([doc]), call([doc])])

        # create_or_update_index gets called once
        expected_schema = {
            "title": self.store.name,
            "additionalProperties": False,
            "type": "object",
            "properties": {
                "id": {"type": "string"},
                "text": {"type": "string"},
                "metadata": {"type": "object"},
                "embeddings": {"type": "array", "format": "vector", "dimensions": 3},
            },
        }

        self.mock_client.create_or_update_index.assert_called_once_with(
            name=self.store.name, schema=expected_schema
        )

    def test_add_documents_when_index_exists(self):
        """When the insert succeeds, no index creation is attempted."""
        self.mock_index.create_many.return_value = [DocStatus(id="1")]
        resp = self.store.add_documents([doc])
        self.assertEqual([DocStatus(id="1")], resp)

        # no calls to create_or_update_index
        # Assert on the child method: ``assert_not_called()`` on the client
        # mock itself only checks that the mock object was never invoked as
        # a callable, and would pass even if the method had been called.
        self.mock_client.create_or_update_index.assert_not_called()

    def test_add_documents_when_project_not_found(self):
        """A NOT_FOUND error that is not about the index is re-raised."""
        self.mock_index.create_many.side_effect = [
            TigrisServerError("", NotFoundRpcError("project not found")),
            [DocStatus(id="1")],
        ]
        with self.assertRaisesRegex(TigrisServerError, "project not found"):
            self.store.add_documents([doc])
        # Checked after the ``with`` block so the assertion actually runs;
        # the second side effect must never be consumed (no retry).
        self.mock_index.create_many.assert_called_once_with([doc])

    def test_delete_documents(self):
        """``delete_documents`` forwards the ids to ``delete_many``."""
        self.store.delete_documents(["id"])
        self.mock_index.delete_many.assert_called_once_with(["id"])

    def test_get_documents(self):
        """``get_documents`` forwards the ids to ``get_many``."""
        self.store.get_documents(["id"])
        self.mock_index.get_many.assert_called_once_with(["id"])

    def test_similarity_search(self):
        """Hits come back as doc/score pairs and the query is a vector query."""
        self.mock_index.search.return_value = Result(
            hits=[
                IndexedDoc(
                    doc=doc,
                    meta=DocMeta(text_match=TextMatchInfo(vector_distance=0.1234)),
                )
            ]
        )
        resp = self.store.similarity_search([1, 1, 1], 12)
        self.assertEqual(1, len(resp))
        self.assertEqual(doc, resp[0].doc)
        self.assertEqual(0.1234, resp[0].score)

        self.mock_index.search.assert_called_once_with(
            query=Query(
                vector_query=VectorField(field="embeddings", vector=[1, 1, 1]),
                hits_per_page=12,
            )
        )
12 changes: 11 additions & 1 deletion tigrisdb/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,24 @@ def __init__(self, config: Optional[ClientConfig]):
config = ClientConfig()
self.__config = config
if not config.server_url:
config.server_url = TigrisClient.__LOCAL_SERVER
config.server_url = os.getenv("TIGRIS_URI", TigrisClient.__LOCAL_SERVER)
if config.server_url.startswith("https://"):
config.server_url = config.server_url.replace("https://", "")
if config.server_url.startswith("http://"):
config.server_url = config.server_url.replace("http://", "")
if ":" not in config.server_url:
config.server_url = f"{config.server_url}:443"

# initialize rest of config
if not config.project_name:
config.project_name = os.getenv("TIGRIS_PROJECT")
if not config.client_id:
config.client_id = os.getenv("TIGRIS_CLIENT_ID")
if not config.client_secret:
config.client_secret = os.getenv("TIGRIS_CLIENT_SECRET")
if not config.branch:
config.branch = os.getenv("TIGRIS_DB_BRANCH", "")

is_local_dev = any(
map(
lambda k: k in config.server_url,
Expand Down
17 changes: 13 additions & 4 deletions tigrisdb/errors.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from typing import cast

import grpc


class TigrisException(Exception):
"""
Base class for all TigrisExceptions
"""
"""Base class for all TigrisExceptions."""

msg: str

Expand All @@ -17,4 +17,13 @@ def __init__(self, msg: str, **kwargs):
# TODO: make this typesafe
class TigrisServerError(TigrisException):
def __init__(self, msg: str, e: grpc.RpcError):
super(TigrisServerError, self).__init__(msg, code=e.code(), details=e.details())
if isinstance(e.code(), grpc.StatusCode):
self.code = cast(grpc.StatusCode, e.code())
else:
self.code = grpc.StatusCode.UNKNOWN

self.details = e.details()
super(TigrisServerError, self).__init__(
msg, code=self.code.name, details=self.details
)
self.__suppress_context__ = True
2 changes: 1 addition & 1 deletion tigrisdb/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

@dataclass
class ClientConfig:
project_name: str
project_name: str = ""
client_id: Optional[str] = None
client_secret: Optional[str] = None
branch: str = ""
Expand Down
1 change: 0 additions & 1 deletion tigrisdb/types/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ def query(self):
return {self.field: self.vector}


# TODO: add filter, collation
@dataclass
class Query:
q: str = ""
Expand Down
24 changes: 24 additions & 0 deletions tigrisdb/types/vector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from dataclasses import InitVar, dataclass
from typing import Dict, List, TypedDict

from tigrisdb.types.search import IndexedDoc, dataclass_default_proto_field


class Document(TypedDict, total=False):
    """Typed dictionary describing a vector-store document.

    Declared with ``total=False``, so every key is optional for type
    checkers.
    """

    # Document identifier.
    id: str
    # Text content of the document.
    text: str
    # Embedding values stored as a list of floats.
    embeddings: List[float]
    # Unparameterized dict; key and value types are not constrained here.
    metadata: Dict


@dataclass
class DocWithScore:
    """A document paired with a score, optionally filled from a search hit.

    ``doc`` and ``score`` may be passed directly. When a truthy
    ``IndexedDoc`` is supplied through the init-only ``_h`` argument, its
    truthy ``doc`` replaces ``doc`` and, if its ``meta`` is truthy,
    ``meta.text_match.vector_distance`` replaces ``score``.
    """

    # NOTE(review): the default is None although the annotation is not
    # Optional - confirm whether callers may observe doc=None.
    doc: Document = None
    score: float = 0.0
    # Init-only value: handed to __post_init__, not stored as a field.
    _h: InitVar[IndexedDoc] = dataclass_default_proto_field

    def __post_init__(self, _h: IndexedDoc):
        """Copy the document and vector distance out of the hit ``_h``."""
        hit = _h
        if not hit:
            # Nothing to copy from; keep the explicitly given values.
            return

        hit_doc = hit.doc
        if hit_doc:
            self.doc = hit_doc

        hit_meta = hit.meta
        if hit_meta:
            self.score = hit_meta.text_match.vector_distance
Loading

0 comments on commit eef942f

Please sign in to comment.