Skip to content

Commit

Permalink
Add generative openai support
Browse files Browse the repository at this point in the history
  • Loading branch information
dirkkul committed Feb 24, 2023
1 parent 10ec692 commit da68913
Show file tree
Hide file tree
Showing 8 changed files with 167 additions and 20 deletions.
23 changes: 23 additions & 0 deletions ci/docker-compose_openai.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
---
version: '3.4'
services:
weaviate_openai:
command:
- --host
- 0.0.0.0
- --port
- '8086'
- --scheme
- http
image:
semitechnologies/weaviate:1.17.4
ports:
- 8086:8086
restart: on-failure:0
environment:
QUERY_DEFAULTS_LIMIT: 25
AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED: 'true'
PERSISTENCE_DATA_PATH: '/var/lib/weaviate'
DEFAULT_VECTORIZER_MODULE: 'text2vec-openai'
ENABLE_MODULES: 'text2vec-openai,generative-openai'
CLUSTER_HOSTNAME: 'node1'
1 change: 1 addition & 0 deletions ci/start_weaviate.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ nohup docker-compose -f ci/docker-compose-azure.yml up -d
nohup docker-compose -f ci/docker-compose-okta-cc.yml up -d
nohup docker-compose -f ci/docker-compose-okta-users.yml up -d
nohup docker-compose -f ci/docker-compose-wcs.yml up -d
nohup docker-compose -f ci/docker-compose_openai.yml up -d
1 change: 1 addition & 0 deletions ci/stop_weaviate.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ docker-compose -f ci/docker-compose-azure.yml down --remove-orphans
docker-compose -f ci/docker-compose-okta-cc.yml down --remove-orphans
docker-compose -f ci/docker-compose-okta-users.yml down --remove-orphans
docker-compose -f ci/docker-compose-wcs.yml down --remove-orphans
docker-compose -f ci/docker-compose_openai.yml down --remove-orphans
47 changes: 47 additions & 0 deletions integration/test_graphql.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import os
from typing import Optional

import pytest

import weaviate
Expand Down Expand Up @@ -114,3 +116,48 @@ def test_hybrid_bm25(client):
# will find more results. "The Crusty Crab" is still first, because it matches with the BM25 search
assert len(result["data"]["Get"]["Ship"]) >= 1
assert result["data"]["Get"]["Ship"][0]["name"] == "The Crusty Crab"


@pytest.mark.parametrize(
    "single,grouped",
    [
        ("Describe the following as a Facebook Ad: {review}", None),
        (None, "Describe the following as a LinkedIn Ad: {review}"),
        (
            "Describe the following as a Twitter Ad: {review}",
            "Describe the following as a Mastodon Ad: {review}",
        ),
    ],
)
def test_generative_openai(single: Optional[str], grouped: Optional[str]):
    """Test the generative-openai module with a single prompt, a grouped task, and both.

    Requires a running Weaviate with the generative-openai module enabled and the
    OPENAI_APIKEY environment variable; the test is skipped otherwise.
    """
    api_key = os.environ.get("OPENAI_APIKEY")
    if api_key is None:
        pytest.skip("No OpenAI API key found.")

    client = weaviate.Client(
        "http://127.0.0.1:8086", additional_headers={"X-OpenAI-Api-Key": api_key}
    )
    client.schema.delete_all()
    wine_class = {
        "class": "Wine",
        "properties": [
            {"name": "name", "dataType": ["string"]},
            {"name": "review", "dataType": ["string"]},
        ],
    }
    client.schema.create_class(wine_class)
    client.data_object.create(
        data_object={"name": "Super expensive wine", "review": "Tastes like a fresh ocean breeze"},
        class_name="Wine",
    )
    client.data_object.create(
        data_object={"name": "cheap wine", "review": "Tastes like forest"}, class_name="Wine"
    )

    result = (
        client.query.get("Wine", ["name", "review"])
        .with_generate(single_prompt=single, grouped_task=grouped)
        .do()
    )
    # The generated text itself is nondeterministic, so the absence of an
    # "error" entry in _additional.generate is the success signal here.
    assert result["data"]["Get"]["Wine"][0]["_additional"]["generate"]["error"] is None
32 changes: 32 additions & 0 deletions test/gql/test_get.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,38 @@ def test_hybrid(query: str, vector: Optional[List[float]], alpha: Optional[float
assert str(hybrid) == expected


@pytest.mark.parametrize(
    "single_prompt,grouped_task,expected",
    [
        (
            "What is the meaning of life?",
            None,
            """generate(singleResult:{prompt:"What is the meaning of life?"}){error singleResult} """,
        ),
        (
            None,
            "Explain why these magazines or newspapers are about finance",
            """generate(groupedResult:{task:"Explain why these magazines or newspapers are about finance"}){error groupedResult} """,
        ),
        (
            "What is the meaning of life?",
            "Explain why these magazines or newspapers are about finance",
            """generate(singleResult:{prompt:"What is the meaning of life?"}groupedResult:{task:"Explain why these magazines or newspapers are about finance"}){error singleResult groupedResult} """,
        ),
    ],
)
def test_generative(single_prompt: Optional[str], grouped_task: Optional[str], expected: str):
    """Check that with_generate() renders the expected `generate(...)` clause
    inside the `_additional` section of the built Get query."""
    query = GetBuilder("Person", "name", None).with_generate(single_prompt, grouped_task).build()
    expected_query = "{Get{Person{name _additional {" + expected + "}}}}"
    assert query == expected_query


@pytest.mark.parametrize("single_prompt,grouped_task", [(123, None), (None, None), (None, 123)])
def test_generative_type(single_prompt, grouped_task):
    """with_generate() must raise TypeError when both arguments are None or
    when a non-str prompt/task is supplied."""
    with pytest.raises(TypeError):
        GetBuilder("Person", "name", None).with_generate(single_prompt, grouped_task).build()


class TestGetBuilder(unittest.TestCase):
def test___init__(self):
"""
Expand Down
19 changes: 2 additions & 17 deletions weaviate/data/replication/replication.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,6 @@
from enum import Enum, EnumMeta, auto
from enum import auto


# MetaEnum and BaseEnum are required to support `in` statements:
# 'ALL' in ConsistencyLevel == True
# 12345 in ConsistencyLevel == False
class MetaEnum(EnumMeta):
def __contains__(cls, item):
try:
# when item is type ConsistencyLevel
return item.name in cls.__members__.keys()
except AttributeError:
# when item is type str
return item in cls.__members__.keys()


class BaseEnum(Enum, metaclass=MetaEnum):
pass
from weaviate.util import BaseEnum


class ConsistencyLevel(str, BaseEnum):
Expand Down
46 changes: 43 additions & 3 deletions weaviate/gql/get.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@
GraphQL `Get` command.
"""
from dataclasses import dataclass
from enum import auto
from json import dumps
from typing import List, Union, Optional, Dict, Tuple

from weaviate.connect import Connection
from weaviate.gql.filter import (
Where,
NearText,
Expand All @@ -15,8 +18,7 @@
NearImage,
Sort,
)
from weaviate.connect import Connection
from weaviate.util import image_encoder_b64, _capitalize_first_letter
from weaviate.util import image_encoder_b64, _capitalize_first_letter, BaseEnum


@dataclass
Expand Down Expand Up @@ -48,6 +50,11 @@ def __str__(self) -> str:
return "hybrid:{" + ret + "}"


# Enum of the two generative-search modes (per-object single prompt vs. one
# grouped task for the whole result set).
# NOTE(review): not referenced by with_generate() or any other code in this
# change — confirm it is used elsewhere, otherwise it is dead code.
# NOTE(review): auto() combined with the str mixin produces stringified
# counter values ("1", "2") rather than the member names — verify that the
# values are never serialized, or assign explicit string values.
class GenerativeType(str, BaseEnum):
    SINGLE = auto()
    GROUPED = auto()


class GetBuilder(GraphQL):
"""
GetBuilder class used to create GraphQL queries.
Expand Down Expand Up @@ -922,12 +929,45 @@ def with_hybrid(
Vector that is searched for. If 'None', weaviate will use the configured text-to-vector module to create a
vector from the "query" field.
By default, None
"""
self._hybrid = Hybrid(query, alpha, vector)
self._contains_filter = True
return self

def with_generate(
    self, single_prompt: Optional[str] = None, grouped_task: Optional[str] = None
) -> "GetBuilder":
    """Generate responses using the generative-openai module.

    Parameters
    ----------
    single_prompt : Optional[str]
        The prompt used to generate one response per result object.
    grouped_task : Optional[str]
        The task used to generate a single response for the whole result group.

    Returns
    -------
    GetBuilder
        The updated builder, to allow method chaining.

    Raises
    ------
    TypeError
        If both arguments are None, or if either is given but is not a str.
    """
    if single_prompt is None and grouped_task is None:
        raise TypeError("Either parameter single_prompt or grouped_task must be not None.")
    if (single_prompt is not None and not isinstance(single_prompt, str)) or (
        grouped_task is not None and not isinstance(grouped_task, str)
    ):
        raise TypeError("prompts and tasks must be of type str.")

    # "error" is always requested so generation failures surface in the payload.
    results: List[str] = ["error"]
    task_and_prompt = ""
    if single_prompt is not None:
        results.append("singleResult")
        task_and_prompt += f'singleResult:{{prompt:"{single_prompt}"}}'
    if grouped_task is not None:
        results.append("groupedResult")
        task_and_prompt += f'groupedResult:{{task:"{grouped_task}"}}'

    # NOTE(review): a prompt/task containing a double quote would break the
    # generated GraphQL string — confirm upstream sanitization or escape here.
    self._additional["__one_level"].add(f'generate({task_and_prompt}){{{" ".join(results)}}}')

    return self

def build(self) -> str:
"""
Build query filter as a string.
Expand Down
18 changes: 18 additions & 0 deletions weaviate/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import json
import os
import uuid as uuid_lib
from enum import Enum, EnumMeta
from io import BufferedReader
from numbers import Real
from typing import Union, Sequence, Any, Optional, List, Dict
Expand All @@ -15,6 +16,23 @@
from weaviate.exceptions import SchemaValidationException


# MetaEnum and BaseEnum are required to support `in` statements:
# 'ALL' in ConsistencyLevel == True
# 12345 in ConsistencyLevel == False
class MetaEnum(EnumMeta):
    def __contains__(cls, item):
        # Enum members expose a .name attribute; for anything else (a plain
        # str, an int, ...) fall back to matching the item itself against the
        # member names. Non-matching types simply yield False.
        candidate = getattr(item, "name", item)
        return candidate in cls.__members__


class BaseEnum(Enum, metaclass=MetaEnum):
    """Enum base class whose metaclass allows membership tests by member name."""


def image_encoder_b64(image_or_image_path: Union[str, BufferedReader]) -> str:
"""
Encode a image in a Weaviate understandable format from a binary read file or by providing
Expand Down

0 comments on commit da68913

Please sign in to comment.