Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Add Weaviate module #492

Merged
merged 5 commits into from
Mar 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ testcontainers-python facilitates the use of Docker containers for functional an
modules/rabbitmq/README
modules/redis/README
modules/selenium/README
modules/weaviate/README

Getting Started
---------------
Expand Down
2 changes: 2 additions & 0 deletions modules/weaviate/README.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
.. autoclass:: testcontainers.weaviate.WeaviateContainer
.. title:: testcontainers.weaviate.WeaviateContainer
178 changes: 178 additions & 0 deletions modules/weaviate/testcontainers/weaviate/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from typing import TYPE_CHECKING, Optional

from requests import ConnectionError, get

from testcontainers.core.generic import DbContainer
from testcontainers.core.waiting_utils import wait_container_is_ready

if TYPE_CHECKING:
from requests import Response


class WeaviateContainer(DbContainer):
"""
Weaviate vector database container.

Arguments:
`image`
Docker image to use with Weaviate container.
`env_vars`
Additional environment variables to include with the container, e.g. ENABLE_MODULES list, QUERY_DEFAULTS_LIMIT setting.

Example:
This example shows how to start Weaviate container with defualt settings.

.. doctest::

>>> from testcontainers.weaviate import WeaviateContainer

>>> with WeaviateContainer() as container:
... with container.get_client() as client:
... client.is_live()
True

This example shows how to start Weaviate container with additinal settings.

.. doctest::

>>> from testcontainers.weaviate import WeaviateContainer

>>> with WeaviateContainer(
... env_vars={
... "ENABLE_MODULES": "backup-filesystem,text2vec-openai",
... "BACKUP_FILESYSTEM_PATH": "/tmp/backups",
... "QUERY_DEFAULTS_LIMIT": 100,
... }
... ) as container:
... with container.get_client() as client:
... client.is_live()
True
"""

def __init__(
self,
image: str = "semitechnologies/weaviate:1.24.5",
env_vars: Optional[dict[str, str]] = None,
**kwargs,
) -> None:
super().__init__(image, **kwargs)
self._http_port = 8080
self._grpc_port = 50051

self.with_command(f"--host 0.0.0.0 --scheme http --port {self._http_port}")
self.with_exposed_ports(self._http_port, self._grpc_port)

if env_vars is not None:
for key, value in env_vars.items():
self.with_env(key, value)

def _configure(self) -> None:
self.with_env("AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED", "true")
self.with_env("PERSISTENCE_DATA_PATH", "/var/lib/weaviate")

@wait_container_is_ready(ConnectionError)
def _connect(self) -> None:
url = f"http://{self.get_http_host()}:{self.get_http_port()}/v1/.well-known/ready"
response: Response = get(url)
response.raise_for_status()

def get_client(
self,
headers: Optional[dict[str, str]] = None,
):
"""
Get a `weaviate.WeaviateClient` instance associated with the container.

Arguments:
`headers`
Additional headers to include in the requests, e.g. API keys for third-party Cloud vectorization.

Returns:
WeaviateClient: An instance of the `weaviate.WeaviateClient` class.
"""

try:
import weaviate
except ImportError as e:
raise ImportError("To use the `get_client` method, you must install the `weaviate-client` package.") from e
return weaviate.connect_to_custom(
http_host=self.get_http_host(),
http_port=self.get_http_port(),
http_secure=self.get_http_secure(),
grpc_host=self.get_http_host(),
grpc_port=self.get_grpc_port(),
grpc_secure=self.get_grpc_secure(),
headers=headers,
)

def get_http_host(self) -> str:
"""
Get the HTTP host of Weaviate container.

Returns:
`str`
The HTTP host of Weaviate container.
"""
return f"{self.get_container_host_ip()}"

def get_http_port(self) -> int:
"""
Get the HTTP port of Weaviate container.

Returns:
`int`
The HTTP port of Weaviate container.
"""
return self.get_exposed_port(self._http_port)

def get_http_secure(self) -> bool:
"""
Get the HTTP secured setting of Weaviate container.

Returns:
`bool`
True if it's https.
"""
return False

def get_grpc_host(self) -> str:
"""
Get the gRPC host of Weaviate container.

Returns:
`str`
The gRPC host of Weaviate container.
"""
return f"{self.get_container_host_ip()}"

def get_grpc_port(self) -> int:
"""
Get the gRPC port of Weaviate container.

Returns:
`int`
The gRPC port of Weaviate container.
"""
return self.get_exposed_port(self._grpc_port)

def get_grpc_secure(self) -> bool:
"""
Get the gRPC secured setting of Weaviate container.

Returns:
`str`
True if the conntection is secured with SSL.
"""
return False
55 changes: 55 additions & 0 deletions modules/weaviate/tests/test_weaviate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from testcontainers.weaviate import WeaviateContainer
import weaviate


def test_docker_run_weaviate():
with WeaviateContainer() as container:
client = weaviate.connect_to_custom(
http_host=container.get_http_host(),
http_port=container.get_http_port(),
http_secure=container.get_http_secure(),
grpc_host=container.get_grpc_host(),
grpc_port=container.get_grpc_port(),
grpc_secure=container.get_grpc_secure(),
)

meta = client.get_meta()
assert len(meta.get("version")) > 0

client.close()


def test_docker_run_weaviate_with_client():
with WeaviateContainer() as container:
with container.get_client() as client:
assert client.is_live()

meta = client.get_meta()
assert len(meta.get("version")) > 0


def test_docker_run_weaviate_with_modules():
enable_modules = [
"backup-filesystem",
"text2vec-openai",
"text2vec-cohere",
"text2vec-huggingface",
"generative-openai",
]
with WeaviateContainer(
env_vars={
"ENABLE_MODULES": ",".join(enable_modules),
"BACKUP_FILESYSTEM_PATH": "/tmp/backups",
}
) as container:
with container.get_client() as client:
assert client.is_live()

meta = client.get_meta()
assert len(meta.get("version")) > 0

modules = meta.get("modules")
assert len(modules) == len(enable_modules)

for name in enable_modules:
assert len(modules[name]) > 0
Loading