diff --git a/ci/docker-compose_openai.yml b/ci/docker-compose_openai.yml new file mode 100644 index 000000000..b730a3bbf --- /dev/null +++ b/ci/docker-compose_openai.yml @@ -0,0 +1,23 @@ +--- +version: '3.4' +services: + weaviate_openai: + command: + - --host + - 0.0.0.0 + - --port + - '8086' + - --scheme + - http + image: + semitechnologies/weaviate:1.17.4 + ports: + - 8086:8086 + restart: on-failure:0 + environment: + QUERY_DEFAULTS_LIMIT: 25 + AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED: 'true' + PERSISTENCE_DATA_PATH: '/var/lib/weaviate' + DEFAULT_VECTORIZER_MODULE: 'text2vec-openai' + ENABLE_MODULES: 'text2vec-openai,generative-openai' + CLUSTER_HOSTNAME: 'node1' \ No newline at end of file diff --git a/ci/start_weaviate.sh b/ci/start_weaviate.sh index a1ecf6cf5..8e58383d3 100755 --- a/ci/start_weaviate.sh +++ b/ci/start_weaviate.sh @@ -6,3 +6,4 @@ nohup docker-compose -f ci/docker-compose-azure.yml up -d nohup docker-compose -f ci/docker-compose-okta-cc.yml up -d nohup docker-compose -f ci/docker-compose-okta-users.yml up -d nohup docker-compose -f ci/docker-compose-wcs.yml up -d +nohup docker-compose -f ci/docker-compose_openai.yml up -d diff --git a/ci/stop_weaviate.sh b/ci/stop_weaviate.sh index 1313dfea5..dbc39de74 100755 --- a/ci/stop_weaviate.sh +++ b/ci/stop_weaviate.sh @@ -5,3 +5,4 @@ docker-compose -f ci/docker-compose-azure.yml down --remove-orphans docker-compose -f ci/docker-compose-okta-cc.yml down --remove-orphans docker-compose -f ci/docker-compose-okta-users.yml down --remove-orphans docker-compose -f ci/docker-compose-wcs.yml down --remove-orphans +docker-compose -f ci/docker-compose_openai.yml down --remove-orphans diff --git a/integration/test_graphql.py b/integration/test_graphql.py index 1bdc32fce..fe6706bdc 100644 --- a/integration/test_graphql.py +++ b/integration/test_graphql.py @@ -1,3 +1,5 @@ +import os + import pytest import weaviate @@ -114,3 +116,48 @@ def test_hybrid_bm25(client): # will find more results. 
"The Crusty Crab" is still first, because it matches with the BM25 search assert len(result["data"]["Get"]["Ship"]) >= 1 assert result["data"]["Get"]["Ship"][0]["name"] == "The Crusty Crab" + + +@pytest.mark.parametrize( + "single,grouped", + [ + ("Describe the following as a Facebook Ad: {review}", None), + (None, "Describe the following as a LinkedIn Ad: {review}"), + ( + "Describe the following as a Twitter Ad: {review}", + "Describe the following as a Mastodon Ad: {review}", + ), + ], +) +def test_generative_openai(single: str, grouped: str): + """Test client credential flow with various providers.""" + api_key = os.environ.get("OPENAI_APIKEY") + if api_key is None: + pytest.skip("No OpenAI API key found.") + + client = weaviate.Client( + "http://127.0.0.1:8086", additional_headers={"X-OpenAI-Api-Key": api_key} + ) + client.schema.delete_all() + wine_class = { + "class": "Wine", + "properties": [ + {"name": "name", "dataType": ["string"]}, + {"name": "review", "dataType": ["string"]}, + ], + } + client.schema.create_class(wine_class) + client.data_object.create( + data_object={"name": "Super expensive wine", "review": "Tastes like a fresh ocean breeze"}, + class_name="Wine", + ) + client.data_object.create( + data_object={"name": "cheap wine", "review": "Tastes like forest"}, class_name="Wine" + ) + + result = ( + client.query.get("Wine", ["name", "review"]) + .with_generate(single_prompt=single, grouped_task=grouped) + .do() + ) + assert result["data"]["Get"]["Wine"][0]["_additional"]["generate"]["error"] is None diff --git a/test/gql/test_get.py b/test/gql/test_get.py index cd6e3f3da..9f58930bf 100644 --- a/test/gql/test_get.py +++ b/test/gql/test_get.py @@ -42,6 +42,38 @@ def test_hybrid(query: str, vector: Optional[List[float]], alpha: Optional[float assert str(hybrid) == expected +@pytest.mark.parametrize( + "single_prompt,grouped_task,expected", + [ + ( + "What is the meaning of life?", + None, + """generate(singleResult:{prompt:"What is the meaning of 
life?"}){error singleResult} """, + ), + ( + None, + "Explain why these magazines or newspapers are about finance", + """generate(groupedResult:{task:"Explain why these magazines or newspapers are about finance"}){error groupedResult} """, + ), + ( + "What is the meaning of life?", + "Explain why these magazines or newspapers are about finance", + """generate(singleResult:{prompt:"What is the meaning of life?"}groupedResult:{task:"Explain why these magazines or newspapers are about finance"}){error singleResult groupedResult} """, + ), + ], +) +def test_generative(single_prompt: str, grouped_task: str, expected: str): + query = GetBuilder("Person", "name", None).with_generate(single_prompt, grouped_task).build() + expected_query = "{Get{Person{name _additional {" + expected + "}}}}" + assert query == expected_query + + +@pytest.mark.parametrize("single_prompt,grouped_task", [(123, None), (None, None), (None, 123)]) +def test_generative_type(single_prompt: str, grouped_task: str): + with pytest.raises(TypeError): + GetBuilder("Person", "name", None).with_generate(single_prompt, grouped_task).build() + + class TestGetBuilder(unittest.TestCase): def test___init__(self): """ diff --git a/weaviate/data/replication/replication.py b/weaviate/data/replication/replication.py index 483e0586f..3087be4ba 100644 --- a/weaviate/data/replication/replication.py +++ b/weaviate/data/replication/replication.py @@ -1,21 +1,6 @@ -from enum import Enum, EnumMeta, auto +from enum import auto - -# MetaEnum and BaseEnum are required to support `in` statements: -# 'ALL' in ConsistencyLevel == True -# 12345 in ConsistencyLevel == False -class MetaEnum(EnumMeta): - def __contains__(cls, item): - try: - # when item is type ConsistencyLevel - return item.name in cls.__members__.keys() - except AttributeError: - # when item is type str - return item in cls.__members__.keys() - - -class BaseEnum(Enum, metaclass=MetaEnum): - pass +from weaviate.util import BaseEnum class ConsistencyLevel(str, 
BaseEnum): diff --git a/weaviate/gql/get.py b/weaviate/gql/get.py index 054af8e16..f8b2d284f 100644 --- a/weaviate/gql/get.py +++ b/weaviate/gql/get.py @@ -2,8 +2,11 @@ GraphQL `Get` command. """ from dataclasses import dataclass +from enum import auto from json import dumps from typing import List, Union, Optional, Dict, Tuple + +from weaviate.connect import Connection from weaviate.gql.filter import ( Where, NearText, @@ -15,8 +18,7 @@ NearImage, Sort, ) -from weaviate.connect import Connection -from weaviate.util import image_encoder_b64, _capitalize_first_letter +from weaviate.util import image_encoder_b64, _capitalize_first_letter, BaseEnum @dataclass @@ -48,6 +50,11 @@ def __str__(self) -> str: return "hybrid:{" + ret + "}" + +class GenerativeType(str, BaseEnum): + SINGLE = auto() + GROUPED = auto() + + class GetBuilder(GraphQL): """ GetBuilder class used to create GraphQL queries. @@ -922,12 +929,45 @@ def with_hybrid( Vector that is searched for. If 'None', weaviate will use the configured text-to-vector module to create a vector from the "query" field. By default, None - """ self._hybrid = Hybrid(query, alpha, vector) self._contains_filter = True return self + + def with_generate( + self, single_prompt: Optional[str] = None, grouped_task: Optional[str] = None + ) -> "GetBuilder": + """Generate responses using the OpenAI generative search. + + Parameters + ---------- + single_prompt: Optional[str] + The prompt to generate a single response. + grouped_task: Optional[str] + The task to generate a grouped response. + """ + if single_prompt is None and grouped_task is None: + raise TypeError( + "Either parameter 'single_prompt' or 'grouped_task' must not be None." 
+ ) + if (single_prompt is not None and not isinstance(single_prompt, str)) or ( + grouped_task is not None and not isinstance(grouped_task, str) + ): + raise TypeError("prompts and tasks must be of type str.") + + results: List[str] = ["error"] + task_and_prompt = "" + if single_prompt is not None: + results.append("singleResult") + task_and_prompt += f'singleResult:{{prompt:"{single_prompt}"}}' + if grouped_task is not None: + results.append("groupedResult") + task_and_prompt += f'groupedResult:{{task:"{grouped_task}"}}' + + self._additional["__one_level"].add(f'generate({task_and_prompt}){{{" ".join(results)}}}') + + return self + def build(self) -> str: """ Build query filter as a string. diff --git a/weaviate/util.py b/weaviate/util.py index cd0a1a5df..50a25253e 100644 --- a/weaviate/util.py +++ b/weaviate/util.py @@ -5,6 +5,7 @@ import json import os import uuid as uuid_lib +from enum import Enum, EnumMeta from io import BufferedReader from numbers import Real from typing import Union, Sequence, Any, Optional, List, Dict @@ -15,6 +16,23 @@ from weaviate.exceptions import SchemaValidationException +# MetaEnum and BaseEnum are required to support `in` statements: +# 'ALL' in ConsistencyLevel == True +# 12345 in ConsistencyLevel == False +class MetaEnum(EnumMeta): + def __contains__(cls, item): + try: + # when item is type ConsistencyLevel + return item.name in cls.__members__.keys() + except AttributeError: + # when item is type str + return item in cls.__members__.keys() + + +class BaseEnum(Enum, metaclass=MetaEnum): + pass + + def image_encoder_b64(image_or_image_path: Union[str, BufferedReader]) -> str: """ Encode a image in a Weaviate understandable format from a binary read file or by providing