Skip to content

Commit

Permalink
feat(core): Add support for ollama module (#618)
Browse files Browse the repository at this point in the history
- Added a new class OllamaContainer with a few methods to handle the
Ollama container.

- The `_check_and_add_gpu_capabilities` method checks if the host has
GPUs and adds the necessary capabilities to the container.

- The `commit_to_image` method allows saving the state of a container
into an image so that it can be reused, especially for containers that
have already pulled some models.
- Added tests to check the functionality of the new class.

> Note: I took inspiration from the Java implementation of the Ollama
module.

Fixes #617

---------

Co-authored-by: David Ankin <daveankin@gmail.com>
  • Loading branch information
bricefotzo and alexanderankin committed Jun 27, 2024
1 parent ead0f79 commit 5442d05
Show file tree
Hide file tree
Showing 5 changed files with 188 additions and 1 deletion.
2 changes: 2 additions & 0 deletions modules/ollama/README.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
.. autoclass:: testcontainers.ollama.OllamaContainer
.. title:: testcontainers.ollama.OllamaContainer
120 changes: 120 additions & 0 deletions modules/ollama/testcontainers/ollama/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

from os import PathLike
from typing import Any, Optional, TypedDict, Union

from docker.types.containers import DeviceRequest
from requests import get

from testcontainers.core.container import DockerContainer
from testcontainers.core.waiting_utils import wait_for_logs


# Shape of one entry in the model list returned by Ollama's ``/api/tags``.
OllamaModel = TypedDict(
    "OllamaModel",
    {
        "name": str,
        "model": str,
        "modified_at": str,
        "size": int,
        "digest": str,
        "details": dict[str, Any],
    },
)


class OllamaContainer(DockerContainer):
    """
    Ollama Container
    Example:
        .. doctest::
            >>> from testcontainers.ollama import OllamaContainer
            >>> with OllamaContainer() as ollama:
            ...     ollama.list_models()
            []
    """

    # Default port the Ollama server listens on inside the container.
    OLLAMA_PORT = 11434

    def __init__(
        self,
        image: str = "ollama/ollama:0.1.44",
        ollama_dir: Optional[Union[str, PathLike]] = None,
        **kwargs,
    ):
        """
        Args:
            image: Docker image to run.
            ollama_dir: Optional host directory mounted at ``/root/.ollama``
                so pulled models persist across container restarts.
            **kwargs: Extra keyword arguments forwarded to ``DockerContainer``.
        """
        super().__init__(image=image, **kwargs)
        self.ollama_dir = ollama_dir
        self.with_exposed_ports(OllamaContainer.OLLAMA_PORT)
        self._check_and_add_gpu_capabilities()

    def _check_and_add_gpu_capabilities(self) -> None:
        """Request all host GPUs when the Docker daemon has the NVIDIA runtime."""
        info = self.get_docker_client().client.info()
        if "nvidia" in info["Runtimes"]:
            # docker-py's ``device_requests`` keyword expects a *list* of
            # DeviceRequest objects; ``count=-1`` requests every available GPU.
            self._kwargs = {**self._kwargs, "device_requests": [DeviceRequest(count=-1, capabilities=[["gpu"]])]}

    def start(self) -> "OllamaContainer":
        """
        Start the Ollama server and wait until it reports it is listening.
        """
        if self.ollama_dir:
            # Mount the host directory so models survive container restarts.
            self.with_volume_mapping(self.ollama_dir, "/root/.ollama", "rw")
        super().start()
        wait_for_logs(self, "Listening on ", timeout=30)

        return self

    def get_endpoint(self) -> str:
        """
        Return the base HTTP URL of the Ollama server, e.g. ``http://host:port``.
        """
        host = self.get_container_host_ip()
        exposed_port = self.get_exposed_port(OllamaContainer.OLLAMA_PORT)
        url = f"http://{host}:{exposed_port}"
        return url

    @property
    def id(self) -> str:
        """
        Return the id of the underlying Docker container.

        Requires the container to be started; accessing it earlier raises
        ``AttributeError`` since ``_container`` is not yet set.
        """
        return self._container.id

    def pull_model(self, model_name: str) -> None:
        """
        Pull a model into the running container via the ``ollama`` CLI.
        Args:
            model_name (str): Name of the model, e.g. ``all-minilm``.
        """
        self.exec(f"ollama pull {model_name}")

    def list_models(self) -> list[OllamaModel]:
        """
        Return the models currently available in the container (``/api/tags``).
        """
        endpoint = self.get_endpoint()
        response = get(url=f"{endpoint}/api/tags")
        response.raise_for_status()
        return response.json().get("models", [])

    def commit_to_image(self, image_name: str) -> None:
        """
        Commit the current container state to a new image, so that containers
        with pulled models can be reused. No-op when an image with that name
        already exists.
        Args:
            image_name (str): Name of the new image
        """
        docker_client = self.get_docker_client()
        existing_images = docker_client.client.images.list(name=image_name)
        if not existing_images and self.id:
            docker_client.client.containers.get(self.id).commit(
                # NOTE(review): the session-id label is cleared — presumably so
                # the committed image outlives the test session; confirm.
                repository=image_name, conf={"Labels": {"org.testcontainers.session-id": ""}}
            )
60 changes: 60 additions & 0 deletions modules/ollama/tests/test_ollama.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import random
import string
from pathlib import Path

import requests
from testcontainers.ollama import OllamaContainer


def random_string(length=6):
    """Return a random string of lowercase ASCII letters of the given length."""
    alphabet = string.ascii_lowercase
    return "".join(random.choice(alphabet) for _ in range(length))


def test_ollama_container():
    """The server's root endpoint answers with a plain-text liveness message."""
    with OllamaContainer() as ollama:
        endpoint = ollama.get_endpoint()
        resp = requests.get(endpoint)
        assert resp.status_code == 200
        assert resp.text == "Ollama is running"


def test_with_default_config():
    """The container reports the version matching the requested image tag."""
    # ``with`` already starts the container; calling start() again would
    # create (and leak) a second container.
    with OllamaContainer("ollama/ollama:0.1.26") as ollama:
        response = requests.get(f"{ollama.get_endpoint()}/api/version")
        version = response.json().get("version")
        assert version == "0.1.26"


def test_download_model_and_commit_to_image():
    """A pulled model survives committing the container to a new image."""
    new_image_name = f"tc-ollama-allminilm-{random_string(length=4).lower()}"
    # ``with`` already starts the container; an extra start() would create
    # (and leak) a second container.
    with OllamaContainer("ollama/ollama:0.1.26") as ollama:
        # Pull the model
        ollama.pull_model("all-minilm")

        model_name = ollama.list_models()[0].get("name")
        assert "all-minilm" in model_name

        # Commit the container state to a new image
        ollama.commit_to_image(new_image_name)

    # Verify the new image
    with OllamaContainer(new_image_name) as ollama:
        response = requests.get(f"{ollama.get_endpoint()}/api/tags")
        model_name = response.json().get("models", [])[0].get("name")
        assert "all-minilm" in model_name


def test_models_saved_in_folder(tmp_path: Path):
    """Models pulled with a host-mounted ollama_dir survive a container restart."""
    with OllamaContainer("ollama/ollama:0.1.26", ollama_dir=tmp_path) as ollama:
        assert not ollama.list_models()
        ollama.pull_model("all-minilm")
        models = ollama.list_models()
        assert len(models) == 1
        assert "all-minilm" in models[0].get("name")

    with OllamaContainer("ollama/ollama:0.1.26", ollama_dir=tmp_path) as ollama:
        models = ollama.list_models()
        assert len(models) == 1
        assert "all-minilm" in models[0].get("name")
4 changes: 3 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ packages = [
{ include = "testcontainers", from = "modules/nats" },
{ include = "testcontainers", from = "modules/neo4j" },
{ include = "testcontainers", from = "modules/nginx" },
{ include = "testcontainers", from = "modules/ollama" },
{ include = "testcontainers", from = "modules/opensearch" },
{ include = "testcontainers", from = "modules/oracle-free" },
{ include = "testcontainers", from = "modules/postgres" },
Expand Down Expand Up @@ -127,6 +128,7 @@ nats = ["nats-py"]
neo4j = ["neo4j"]
nginx = []
opensearch = ["opensearch-py"]
ollama = []
oracle = ["sqlalchemy", "oracledb"]
oracle-free = ["sqlalchemy", "oracledb"]
postgres = []
Expand Down Expand Up @@ -272,6 +274,7 @@ mypy_path = [
# "modules/mysql",
# "modules/neo4j",
# "modules/nginx",
# "modules/ollama",
# "modules/opensearch",
# "modules/oracle",
# "modules/postgres",
Expand Down

0 comments on commit 5442d05

Please sign in to comment.