Skip to content

Commit

Permalink
feat: search client (#18)
Browse files Browse the repository at this point in the history
* ci: added codecov

* feat: init search index
  • Loading branch information
adilansari committed May 16, 2023
1 parent 8f82e24 commit 8bd948b
Show file tree
Hide file tree
Showing 15 changed files with 368 additions and 25 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ jobs:
source $VENV
poetry run coverage run -m unittest discover -s tests -p "test_*.py"
poetry run coverage report -m
- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v3

build:
needs: lint
Expand Down
Empty file added codecov.yml
Empty file.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ optional = true

[tool.poetry.group.dev.dependencies]
pre-commit = "^3.3.0"
coverage = {version = "^7.2.5", extras = ["toml"]}
coverage = { version = "^7.2.5", extras = ["toml"] }

[tool.poetry.scripts]
make = "scripts.proto:main"
Expand All @@ -35,7 +35,7 @@ make = "scripts.proto:main"
source = ["tigrisdb"]

[tool.coverage.report]
fail_under = 35
fail_under = 40

[build-system]
requires = ["poetry-core"]
Expand Down
2 changes: 1 addition & 1 deletion scripts/proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def generate():
if pd_file.endswith(".proto"):
proto_sources.append(os.path.join(pd_path, pd_file))

for pf in ["api.proto", "search.proto", "auth.proto"]:
for pf in ["api.proto", "search.proto", "auth.proto", "observability.proto"]:
pf_path = os.path.join(TIGRIS_PROTO_DIR, pf)
proto_sources.append(pf_path)

Expand Down
12 changes: 11 additions & 1 deletion tests/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@

import grpc

from api.generated.server.v1.auth_pb2 import GetAccessTokenResponse
from api.generated.server.v1.auth_pb2 import (
CLIENT_CREDENTIALS,
GetAccessTokenRequest,
GetAccessTokenResponse,
)
from tests import StubRpcError
from tigrisdb.auth import AuthGateway
from tigrisdb.errors import TigrisServerError
Expand Down Expand Up @@ -35,9 +39,15 @@ def test_get_access_token_with_valid_token_refresh_window(
actual_token = auth_gateway.get_access_token()
self.assertEqual(expected_token, actual_token)
next_refresh = auth_gateway.__getattribute__("_AuthGateway__next_refresh_time")

# refresh time is within 11 minutes of expiration time
self.assertLessEqual(expiration_time - next_refresh, 660)

# request validation
mock_grpc_auth.GetAccessToken.assert_called_once_with(
GetAccessTokenRequest(grant_type=CLIENT_CREDENTIALS)
)

def test_get_access_token_with_rpc_failure(self, channel_ready_future, grpc_auth):
channel_ready_future.return_value = self.done_future
mock_grpc_auth = grpc_auth()
Expand Down
66 changes: 66 additions & 0 deletions tests/test_search_index.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from unittest import TestCase
from unittest.mock import patch

from api.generated.server.v1.observability_pb2 import Error as ProtoError
from api.generated.server.v1.search_pb2 import DocStatus, UpdateDocumentResponse
from tests import StubRpcError
from tigrisdb.errors import TigrisServerError
from tigrisdb.search_index import SearchIndex
from tigrisdb.types import ClientConfig, Document
from tigrisdb.utils import bytes_to_dict


@patch("api.generated.server.v1.search_pb2_grpc.SearchStub")
class SearchIndexTest(TestCase):
def setUp(self) -> None:
self.client_config = ClientConfig(
server_url="localhost:5000", project_name="db1"
)
self.index_name = "catalog"

def test_create_one(self, grpc_search):
doc: Document = {"item_id": 1, "name": "shoe", "brand": "adidas"}
with patch.object(
SearchIndex, "create_many", return_value="some_str"
) as mock_create_many:
search_index = SearchIndex(
self.index_name, grpc_search(), self.client_config
)
search_index.create_one(doc)

mock_create_many.assert_called_once_with([doc])

def test_update_many(self, grpc_search):
docs = [{"id": 1, "name": "shoe"}, {"id": 2, "name": "jacket"}]
search_index = SearchIndex(self.index_name, grpc_search(), self.client_config)
mock_grpc = grpc_search()
mock_grpc.Update.return_value = UpdateDocumentResponse(
status=[
DocStatus(id="1"),
DocStatus(id="2", error=ProtoError(message="conflict")),
]
)

resp = search_index.update_many(docs)
self.assertEqual(resp[0].id, "1")
self.assertIsNone(resp[0].error)
self.assertEqual(resp[1].id, "2")
self.assertRegex(resp[1].error.msg, "conflict")

mock_grpc.Update.assert_called_once()
called_with = mock_grpc.Update.call_args.args[0]
self.assertEqual(called_with.project, search_index.project)
self.assertEqual(called_with.index, search_index.name)
self.assertEqual(list(map(bytes_to_dict, called_with.documents)), docs)

def test_update_many_with_error(self, grpc_search):
docs = [{"id": 1, "name": "shoe"}, {"id": 2, "name": "jacket"}]
search_index = SearchIndex(self.index_name, grpc_search(), self.client_config)
mock_grpc = grpc_search()
mock_grpc.Update.side_effect = StubRpcError(
code="Unavailable", details="operational failure"
)

with self.assertRaisesRegex(TigrisServerError, "operational failure") as e:
search_index.update_many(docs)
self.assertIsNotNone(e)
15 changes: 11 additions & 4 deletions tigrisdb/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,20 @@

import grpc

from api.generated.server.v1 import api_pb2_grpc as tigris_grpc
from api.generated.server.v1.api_pb2_grpc import TigrisStub
from api.generated.server.v1.search_pb2_grpc import SearchStub
from tigrisdb.auth import AuthGateway
from tigrisdb.database import Database
from tigrisdb.errors import TigrisException
from tigrisdb.search import Search
from tigrisdb.types import ClientConfig


class TigrisClient(object):
__LOCAL_SERVER = "localhost:8081"

__tigris_stub: tigris_grpc.TigrisStub
__tigris_client: TigrisStub
__search_client: SearchStub
__config: ClientConfig

def __init__(self, config: Optional[ClientConfig]):
Expand Down Expand Up @@ -54,7 +57,11 @@ def __init__(self, config: Optional[ClientConfig]):
except grpc.FutureTimeoutError:
raise TigrisException(f"Connection timed out {config.server_url}")

self.__tigris_stub = tigris_grpc.TigrisStub(channel)
self.__tigris_client = TigrisStub(channel)
self.__search_client = SearchStub(channel)

def get_db(self):
return Database(self.__tigris_stub, self.__config)
return Database(self.__tigris_client, self.__config)

def get_search(self):
return Search(self.__search_client, self.__config)
17 changes: 10 additions & 7 deletions tigrisdb/collection.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import List

import grpc

from api.generated.server.v1.api_pb2 import InsertRequest, InsertResponse, ReadRequest
Expand Down Expand Up @@ -31,7 +33,7 @@ def branch(self):
def name(self):
return self.__name

def insert_many(self, docs: list[Document]) -> bool:
def insert_many(self, docs: List[Document]) -> bool:
doc_bytes = map(dict_to_bytes, docs)
req = InsertRequest(
project=self.project,
Expand All @@ -41,14 +43,15 @@ def insert_many(self, docs: list[Document]) -> bool:
)
try:
resp: InsertResponse = self.__client.Insert(req)
if resp.status == "inserted":
return True
else:
raise TigrisException(f"failed to insert docs: {resp.status}")
except grpc.RpcError as e:
raise TigrisServerError("failed to insert documents", e)

def find_many(self) -> list[Document]:
if resp.status == "inserted":
return True
else:
raise TigrisException(f"failed to insert docs: {resp.status}")

def find_many(self) -> List[Document]:
req = ReadRequest(
project=self.project,
branch=self.branch,
Expand All @@ -60,7 +63,7 @@ def find_many(self) -> list[Document]:
except grpc.RpcError as e:
raise TigrisServerError("failed to read documents", e)

docs: list[Document] = []
docs: List[Document] = []
for r in doc_iterator:
docs.append(bytes_to_dict(r.data))

Expand Down
15 changes: 7 additions & 8 deletions tigrisdb/database.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import json

import grpc

from api.generated.server.v1.api_pb2 import (
Expand All @@ -12,6 +10,7 @@
from tigrisdb.collection import Collection
from tigrisdb.errors import TigrisException, TigrisServerError
from tigrisdb.types import ClientConfig
from tigrisdb.utils import schema_to_bytes


class Database:
Expand All @@ -31,25 +30,25 @@ def branch(self):
return self.__config.branch

def create_or_update_collection(self, name: str, schema: dict) -> Collection:
schema_str = json.dumps(schema)
req = CreateOrUpdateCollectionRequest(
project=self.project,
branch=self.branch,
collection=name,
schema=schema_str.encode(),
schema=schema_to_bytes(schema),
only_create=False,
)
try:
resp: CreateOrUpdateCollectionResponse = (
self.__client.CreateOrUpdateCollection(req)
)
if resp.status == "created":
return Collection(name, self.__client, self.__config)
else:
raise TigrisException(f"failed to create collection: {resp.message}")
except grpc.RpcError as e:
raise TigrisServerError("failed to create collection", e)

if resp.status == "created":
return Collection(name, self.__client, self.__config)
else:
raise TigrisException(f"failed to create collection: {resp.message}")

def drop_collection(self, name: str) -> bool:
req = DropCollectionRequest(
project=self.project,
Expand Down
1 change: 1 addition & 0 deletions tigrisdb/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def __init__(self, msg: str, **kwargs):
super(TigrisException, self).__init__(kwargs)


# TODO: make this typesafe
class TigrisServerError(TigrisException):
def __init__(self, msg: str, e: grpc.RpcError):
super(TigrisServerError, self).__init__(msg, code=e.code(), details=e.details())
51 changes: 51 additions & 0 deletions tigrisdb/search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import grpc

from api.generated.server.v1.search_pb2 import (
CreateOrUpdateIndexRequest,
CreateOrUpdateIndexResponse,
DeleteIndexRequest,
DeleteIndexResponse,
)
from api.generated.server.v1.search_pb2_grpc import SearchStub
from tigrisdb.errors import TigrisException, TigrisServerError
from tigrisdb.search_index import SearchIndex
from tigrisdb.types import ClientConfig
from tigrisdb.utils import schema_to_bytes


class Search:
__client: SearchStub
__config: ClientConfig

def __init__(self, client: SearchStub, config: ClientConfig):
self.__client = client
self.__config = config

@property
def project(self):
return self.__config.project_name

def create_or_update_index(self, name: str, schema: dict) -> SearchIndex:
req = CreateOrUpdateIndexRequest(
project=self.project, name=name, schema=schema_to_bytes(schema)
)
try:
resp: CreateOrUpdateIndexResponse = self.__client.CreateOrUpdateIndex(req)
except grpc.RpcError as e:
raise TigrisServerError("failed to create search index", e)

if resp.status == "created":
return SearchIndex(
index_name=name, client=self.__client, config=self.__config
)

raise TigrisException(f"Invalid response to create search index: {resp.status}")

def delete_index(self, name: str) -> bool:
req = DeleteIndexRequest(name=name, project=self.project)
try:
resp: DeleteIndexResponse = self.__client.DeleteIndex(req)
except grpc.RpcError as e:
raise TigrisServerError("failed to delete search index", e)

return resp.status == "deleted"
Loading

0 comments on commit 8bd948b

Please sign in to comment.