Skip to content

Commit

Permalink
feat: langchain compatibility and simpler client instantiation (#46)
Browse files Browse the repository at this point in the history
* fix: Downgrade grpc and proto dependencies to be langchain compatible

* feat: change VectorStore stub to be dependent on Search Client

* feat: an option to create a client using dictionary

* chore: shorten import paths

* refactor: config merge method to override config using dictionary keys

* fix: Downgrade protobuf usage to 3.19.6

* ading api

* resetting versions to latest

* stepping down proto requirements

* fixed proto ts parsing with tz info

* open ended protobuf

* removing generated API files
  • Loading branch information
adilansari committed Jun 4, 2023
1 parent 0b76b6b commit 667d484
Show file tree
Hide file tree
Showing 19 changed files with 387 additions and 149 deletions.
140 changes: 75 additions & 65 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ include = ["api/generated/**/*"]

[tool.poetry.dependencies]
python = ">=3.8,<4.0"
protobuf = "^4.22.3"
protobuf = ">=3.19.6"
grpcio-tools = ">=1.46.0"

[tool.poetry.group.dev]
Expand Down
4 changes: 2 additions & 2 deletions scripts/proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ def generate():
for proto_file in proto_sources:
cmd_run(
f"python -m grpc_tools.protoc --proto_path={PROTO_ROOT}"
f" --python_out={GENERATED_PROTO_DIR} --pyi_out={GENERATED_PROTO_DIR} "
f"--grpc_python_out={GENERATED_PROTO_DIR} {proto_file}",
f" --python_out={GENERATED_PROTO_DIR}"
f" --grpc_python_out={GENERATED_PROTO_DIR} {proto_file}",
shell=True,
check=True,
)
Expand Down
4 changes: 1 addition & 3 deletions tests/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@
class AuthGatewayTest(TestCase):
def setUp(self) -> None:
self.done_future = MagicMock(grpc.Future)
self.client_config = ClientConfig(
server_url="localhost:5000", project_name="db1"
)
self.client_config = ClientConfig(server_url="localhost:5000", project="db1")

def test_get_access_token_with_valid_token_refresh_window(
self, channel_ready_future, grpc_auth
Expand Down
90 changes: 68 additions & 22 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,31 +13,49 @@ class TigrisClientTest(TestCase):
def setUp(self) -> None:
self.done_future = MagicMock(grpc.Future)

def test_without_config(self, ready_future):
def test_init_with_none(self, ready_future):
ready_future.return_value = self.done_future
client = TigrisClient()
self.assertEqual("localhost:8081", client.config.server_url)
self.assertIsNone(client.config.client_id)
self.assertIsNone(client.config.client_secret)
self.assertIsNone(client.config.project_name)
self.assertEqual("", client.config.branch)
with self.assertRaisesRegex(
ValueError, "`TIGRIS_PROJECT` environment variable"
):
TigrisClient()

def test_init_with_complete_dict(self, ready_future):
ready_future.return_value = self.done_future
client = TigrisClient(
{
"server_url": "uri",
"project": "project",
"client_id": "client",
"client_secret": "secret",
"branch": "branch",
}
)
self.assertEqual("uri:443", client.config.server_url)
self.assertEqual("client", client.config.client_id)
self.assertEqual("secret", client.config.client_secret)
self.assertEqual("project", client.config.project)
self.assertEqual("branch", client.config.branch)

@patch.dict(
os.environ,
{
"TIGRIS_URI": "uri_env",
"TIGRIS_URI": "localhost:5000",
"TIGRIS_PROJECT": "project_env",
"TIGRIS_CLIENT_ID": "client_env",
"TIGRIS_CLIENT_SECRET": "secret_env",
"TIGRIS_DB_BRANCH": "branch_env",
},
)
def test_with_config(self, ready_future):
def test_init_with_partial_dict(self, ready_future):
ready_future.return_value = self.done_future
client = TigrisClient({"project": "p1"})
self.assertEqual("localhost:5000", client.config.server_url)
self.assertEqual("p1", client.config.project)

def test_init_with_config(self, ready_future):
ready_future.return_value = self.done_future
client = TigrisClient(
config=ClientConfig(
conf=ClientConfig(
server_url="test_url",
project_name="test_project",
project="test_project",
client_id="test_client_id",
client_secret="test_client_secret",
branch="test_branch",
Expand All @@ -46,9 +64,24 @@ def test_with_config(self, ready_future):
self.assertEqual("test_url:443", client.config.server_url)
self.assertEqual("test_client_id", client.config.client_id)
self.assertEqual("test_client_secret", client.config.client_secret)
self.assertEqual("test_project", client.config.project_name)
self.assertEqual("test_project", client.config.project)
self.assertEqual("test_branch", client.config.branch)

def test_init_local_dev(self, ready_future):
ready_future.return_value = self.done_future
client = TigrisClient(
{"server_url": "localhost:5000", "project": "test_project"}
)
self.assertEqual("localhost:5000", client.config.server_url)
self.assertEqual("test_project", client.config.project)

def test_init_failing_validation(self, ready_future):
ready_future.return_value = self.done_future
with self.assertRaisesRegex(
ValueError, "`TIGRIS_PROJECT` environment variable"
):
TigrisClient({"server_url": "localhost:5000"})

@patch.dict(
os.environ,
{
Expand All @@ -65,23 +98,36 @@ def test_with_env_vars(self, ready_future):
self.assertEqual("uri_env:443", client.config.server_url)
self.assertEqual("client_env", client.config.client_id)
self.assertEqual("secret_env", client.config.client_secret)
self.assertEqual("project_env", client.config.project_name)
self.assertEqual("project_env", client.config.project)
self.assertEqual("branch_env", client.config.branch)

def test_strip_https(self, ready_future):
ready_future.return_value = self.done_future
client = TigrisClient(config=ClientConfig(server_url="https://my.tigris.dev"))
conf = ClientConfig(
server_url="https://my.tigris.dev",
project="p1",
client_id="id",
client_secret="secret",
)
client = TigrisClient(conf)
self.assertEqual("my.tigris.dev:443", client.config.server_url)
client = TigrisClient(config=ClientConfig(server_url="http://my.tigris.dev"))

conf.server_url = "http://my.tigris.dev"
client = TigrisClient(conf)
self.assertEqual("my.tigris.dev:443", client.config.server_url)

def test_get_db(self, ready_future):
ready_future.return_value = self.done_future
client = TigrisClient()
client = TigrisClient(ClientConfig(project="p1", server_url="localhost:5000"))
self.assertEqual(client.config.branch, client.get_db().branch)
self.assertEqual(client.config.project_name, client.get_db().project)
self.assertEqual(client.config.project, client.get_db().project)

def test_get_search(self, ready_future):
ready_future.return_value = self.done_future
client = TigrisClient()
self.assertEqual(client.config.project_name, client.get_search().project)
client = TigrisClient(ClientConfig(project="p1", server_url="localhost:5000"))
self.assertEqual(client.config.project, client.get_search().project)

def test_get_vector_search(self, ready_future):
ready_future.return_value = self.done_future
client = TigrisClient(ClientConfig(project="p1", server_url="localhost:5000"))
self.assertEqual("v1", client.get_vector_store("v1").name)
6 changes: 2 additions & 4 deletions tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,12 @@
@patch("api.generated.server.v1.search_pb2_grpc.SearchStub")
class SearchTest(TestCase):
def setUp(self) -> None:
self.client_config = ClientConfig(
server_url="localhost:5000", project_name="db1"
)
self.client_config = ClientConfig(server_url="localhost:5000", project="db1")

def test_get_index(self, grpc_search):
mock_grpc = grpc_search()
search = Search(mock_grpc, self.client_config)
search_index = search.get_index("test-index")

self.assertEqual("test-index", search_index.name)
self.assertEqual(self.client_config.project_name, search_index.project)
self.assertEqual(self.client_config.project, search_index.project)
4 changes: 1 addition & 3 deletions tests/test_search_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,7 @@
@patch("api.generated.server.v1.search_pb2_grpc.SearchStub")
class SearchIndexTest(TestCase):
def setUp(self) -> None:
self.client_config = ClientConfig(
server_url="localhost:5000", project_name="db1"
)
self.client_config = ClientConfig(server_url="localhost:5000", project="db1")
self.index_name = "catalog"

def test_search(self, grpc_search):
Expand Down
145 changes: 145 additions & 0 deletions tests/test_types_client_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
import os
from unittest import TestCase
from unittest.mock import patch

from tigrisdb.types import ClientConfig


class ClientConfigTest(TestCase):
def setUp(self) -> None:
pass

def test_init_with_none(self):
conf = ClientConfig()
self.assertIsNone(conf.project)
self.assertEqual(conf.server_url, "api.preview.tigrisdata.cloud")
self.assertIsNone(conf.client_id)
self.assertIsNone(conf.client_secret)
self.assertEqual(conf.branch, "")
self.assertFalse(conf.is_local_dev())

@patch.dict(
os.environ,
{
"TIGRIS_URI": "uri_env",
"TIGRIS_PROJECT": "project_env",
"TIGRIS_CLIENT_ID": "client_env",
"TIGRIS_CLIENT_SECRET": "secret_env",
"TIGRIS_DB_BRANCH": "branch_env",
},
)
def test_init_with_all_args(self):
conf = ClientConfig(
server_url="uri",
project="project",
client_id="client",
client_secret="secret",
branch="branch",
)
self.assertEqual(conf.server_url, "uri")
self.assertEqual(conf.project, "project")
self.assertEqual(conf.client_id, "client")
self.assertEqual(conf.client_secret, "secret")
self.assertEqual(conf.branch, "branch")

@patch.dict(
os.environ,
{
"TIGRIS_URI": "uri_env",
"TIGRIS_PROJECT": "project_env",
"TIGRIS_CLIENT_ID": "client_env",
"TIGRIS_CLIENT_SECRET": "secret_env",
"TIGRIS_DB_BRANCH": "branch_env",
},
)
def test_init_with_all_env(self):
conf = ClientConfig()
self.assertEqual(conf.server_url, "uri_env")
self.assertEqual(conf.project, "project_env")
self.assertEqual(conf.client_id, "client_env")
self.assertEqual(conf.client_secret, "secret_env")
self.assertEqual(conf.branch, "branch_env")

@patch.dict(
os.environ,
{
"TIGRIS_CLIENT_ID": "client_env",
"TIGRIS_DB_BRANCH": "branch_env",
},
)
def test_init_with_partial_env(self):
conf = ClientConfig(project="project")
self.assertEqual(conf.project, "project")
self.assertEqual(conf.client_id, "client_env")
self.assertIsNone(conf.client_secret)
self.assertEqual(conf.branch, "branch_env")

@patch.dict(
os.environ,
{
"TIGRIS_URI": "uri_env",
"TIGRIS_PROJECT": "project_env",
"TIGRIS_CLIENT_ID": "client_env",
"TIGRIS_CLIENT_SECRET": "secret_env",
"TIGRIS_DB_BRANCH": "branch_env",
},
)
def test_merge_override_all(self):
conf = ClientConfig(
server_url="uri",
project="project",
client_id="client",
client_secret="secret",
branch="branch",
)
conf.merge(
project="project_dict",
client_id="client_dict",
client_secret="secret_dict",
server_url="uri_dict",
branch="branch_dict",
)
self.assertEqual(conf.server_url, "uri_dict")
self.assertEqual(conf.project, "project_dict")
self.assertEqual(conf.client_id, "client_dict")
self.assertEqual(conf.client_secret, "secret_dict")
self.assertEqual(conf.branch, "branch_dict")

def test_local_dev(self):
cases = [
(ClientConfig(), False),
(ClientConfig(server_url="localhost:1234"), True),
(ClientConfig(server_url="127.0.0.1"), True),
(ClientConfig(server_url="[::1]"), True),
(ClientConfig(server_url="https://tigrisdb-local-server:1234"), True),
(ClientConfig(server_url="https://api.tigris.cloud"), False),
]
for conf, expected in cases:
with self.subTest(conf.server_url):
self.assertEqual(expected, conf.is_local_dev())

def test_validate_without_project(self):
conf = ClientConfig()
with self.assertRaisesRegex(
ValueError, "`TIGRIS_PROJECT` environment variable"
):
conf.validate()

def test_validate_without_client_id(self):
conf = ClientConfig(project="project")
with self.assertRaisesRegex(
ValueError, "`TIGRIS_CLIENT_ID` environment variable"
):
conf.validate()

def test_validate_without_client_secret(self):
conf = ClientConfig(project="project", client_id="id")
with self.assertRaisesRegex(
ValueError, "`TIGRIS_CLIENT_SECRET` environment variable"
):
conf.validate()

def test_validate_no_error(self):
conf = ClientConfig(project="project", client_id="id", client_secret="secret")
conf.validate()
self.assertFalse(conf.is_local_dev())
10 changes: 3 additions & 7 deletions tests/test_vector_store.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from unittest import TestCase
from unittest.mock import MagicMock, Mock, call, patch
from unittest.mock import Mock, call

from tests import NotFoundRpcError
from tigrisdb.errors import TigrisServerError
Expand Down Expand Up @@ -28,12 +28,8 @@ class VectorStoreTest(TestCase):
def setUp(self) -> None:
self.mock_index = Mock(spec=SearchIndex)
self.mock_client = Mock(spec=Search)
with patch("tigrisdb.client.TigrisClient.__new__") as mock_tigris:
instance = MagicMock()
mock_tigris.return_value = instance
instance.get_search.return_value = self.mock_client
self.mock_client.get_index.return_value = self.mock_index
self.store = VectorStore("my_vectors")
self.mock_client.get_index.return_value = self.mock_index
self.store = VectorStore(self.mock_client, "my_vectors")

def test_add_documents_when_index_not_found(self):
# throw error on first call and succeed on second
Expand Down
6 changes: 6 additions & 0 deletions tigrisdb/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from tigrisdb.client import TigrisClient # noqa: F401
from tigrisdb.collection import Collection # noqa: F401
from tigrisdb.database import Database # noqa: F401
from tigrisdb.search import Search # noqa: F401
from tigrisdb.search_index import SearchIndex # noqa: F401
from tigrisdb.vector_store import VectorStore # noqa: F401
Loading

0 comments on commit 667d484

Please sign in to comment.