diff --git a/examples/openai_examples/create_image.py b/examples/openai_examples/create_image.py
new file mode 100644
index 00000000..6fac11ea
--- /dev/null
+++ b/examples/openai_examples/create_image.py
@@ -0,0 +1,67 @@
+import time
+import base64
+from io import BytesIO
+from PIL import Image
+
+from gptcache.adapter import openai
+from gptcache.processor.pre import get_prompt
+from gptcache import cache
+
+from gptcache.embedding import Onnx
+from gptcache.similarity_evaluation.distance import SearchDistanceEvaluation
+from gptcache.manager import get_data_manager, CacheBase, VectorBase
+
+
+onnx = Onnx()
+cache_base = CacheBase('sqlite')
+vector_base = VectorBase('milvus', host='localhost', port='19530', collection_name='gptcache_image', dimension=onnx.dimension)
+data_manager = get_data_manager(cache_base, vector_base)
+
+cache.init(
+    pre_embedding_func=get_prompt,
+    embedding_func=onnx.to_embeddings,
+    data_manager=data_manager,
+    similarity_evaluation=SearchDistanceEvaluation(),
+    )
+cache.set_openai_key()
+
+##################### Create an image with prompt1 ###################
+prompt1 = 'a cat sitting beside a dog'
+size1 = '256x256'
+
+start = time.time()
+response1 = openai.Image.create(
+    prompt=prompt1,
+    size=size1,
+    response_format='b64_json'
+    )
+end = time.time()
+print('Time elapsed:', end - start)
+
+img_b64_1 = response1['data'][0]['b64_json']
+img_bytes_1 = base64.b64decode(img_b64_1)
+img_file_1 = BytesIO(img_bytes_1)  # convert the image bytes to a file-like object
+img_1 = Image.open(img_file_1)  # decode into a PIL image
+assert img_1.size == tuple([int(x) for x in size1.split('x')]), \
+    f'Expected to generate an image of size {size1} but got {img_1.size}.'
+
+
+##################### Create an image with prompt2 #####################
+prompt2 = 'a dog sitting beside a cat'
+size2 = '512x512'
+
+start = time.time()
+response2 = openai.Image.create(
+    prompt=prompt2,
+    size=size2,
+    response_format='b64_json'
+    )
+end = time.time()
+print('Time elapsed:', end - start)
+
+img_b64_2 = response2['data'][0]['b64_json']
+img_bytes_2 = base64.b64decode(img_b64_2)
+img_file_2 = BytesIO(img_bytes_2)  # convert the image bytes to a file-like object
+img_2 = Image.open(img_file_2)  # decode into a PIL image
+assert img_2.size == tuple([int(x) for x in size2.split('x')]), \
+    f'Expected to generate an image of size {size2} but got {img_2.size}.'
diff --git a/gptcache/adapter/openai.py b/gptcache/adapter/openai.py
index 5ea1e8ac..96d67f73 100644
--- a/gptcache/adapter/openai.py
+++ b/gptcache/adapter/openai.py
@@ -1,15 +1,27 @@
 import time
 from typing import Iterator
+import base64
+from io import BytesIO
+import os
+
 import openai
+
 from gptcache import CacheError
 from gptcache.adapter.adapter import adapt
 from gptcache.utils.response import (
     get_stream_message_from_openai_answer,
     get_message_from_openai_answer,
     get_text_from_openai_answer,
+    get_image_from_openai_b64,
+    get_image_from_openai_url
 )
+from gptcache.utils import import_pillow
+
+import_pillow()
+
+from PIL import Image as PILImage  # pylint: disable=C0413
 
 
 class ChatCompletion(openai.ChatCompletion):
@@ -53,6 +65,36 @@ def cache_data_convert(cache_data):
             **kwargs
         )
 
+class Image(openai.Image):
+    """Openai Image Wrapper"""
+
+    @classmethod
+    def create(cls, *args, **kwargs):
+        def llm_handler(*llm_args, **llm_kwargs):
+            try:
+                return openai.Image.create(*llm_args, **llm_kwargs)
+            except Exception as e:
+                raise CacheError("openai error") from e
+
+        def cache_data_convert(cache_data):
+            return construct_image_create_resp_from_cache(
+                image_data=cache_data,
+                response_format=kwargs["response_format"],
+                size=kwargs["size"]
+            )
+
+        def update_cache_callback(llm_data, update_cache_func):
+            if kwargs["response_format"] == "b64_json":
+                update_cache_func(get_image_from_openai_b64(llm_data))
+                return llm_data
+            elif kwargs["response_format"] == "url":
+                update_cache_func(get_image_from_openai_url(llm_data))
+                return llm_data
+
+        return adapt(
+            llm_handler, cache_data_convert, update_cache_callback, *args, **kwargs
+        )
+
 
 def construct_resp_from_cache(return_message):
     return {
@@ -141,3 +183,33 @@
         "usage": {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0},
         "object": "text_completion",
     }
+
+
+def construct_image_create_resp_from_cache(image_data, response_format, size):
+    img_bytes = base64.b64decode(image_data)
+    img_file = BytesIO(img_bytes)  # convert image to file-like object
+    img = PILImage.open(img_file)
+    new_size = tuple(int(a) for a in size.split("x"))
+    if new_size != img.size:
+        img = img.resize(new_size)
+        buffered = BytesIO()
+        img.save(buffered, format="JPEG")
+    else:
+        buffered = img_file
+
+    if response_format == "url":
+        target_url = os.path.abspath(str(int(time.time())) + ".jpeg")
+        with open(target_url, "wb") as f:
+            f.write(buffered.getvalue())
+        image_data = target_url
+    elif response_format == "b64_json":
+        image_data = base64.b64encode(buffered.getvalue())
+    else:
+        raise AttributeError(f"Invalid response_format: {response_format} is not one of ['url', 'b64_json']")
+
+    return {
+        "created": int(time.time()),
+        "data": [
+            {response_format: image_data}
+        ]
+    }
diff --git a/gptcache/utils/__init__.py b/gptcache/utils/__init__.py
index ec0dae97..64ff88d1 100644
--- a/gptcache/utils/__init__.py
+++ b/gptcache/utils/__init__.py
@@ -15,6 +15,7 @@
     "import_sql_client",
     "import_pydantic",
     "import_langchain",
+    "import_pillow"
 ]
 
 import importlib.util
@@ -120,3 +121,7 @@ def import_pydantic():
 
 def import_langchain():
     _check_library("langchain")
+
+
+def import_pillow():
+    _check_library("pillow")
diff --git a/gptcache/utils/response.py b/gptcache/utils/response.py
index 648542dd..4d44a21f 100644
--- a/gptcache/utils/response.py
+++ b/gptcache/utils/response.py
@@ -1,3 +1,6 @@
+import base64
+import requests
+
 def get_message_from_openai_answer(openai_resp):
     return openai_resp["choices"][0]["message"]["content"]
 
@@ -8,3 +11,21 @@ def get_stream_message_from_openai_answer(openai_data):
 
 def get_text_from_openai_answer(openai_resp):
     return openai_resp["choices"][0]["text"]
+
+
+def get_image_from_openai_b64(openai_resp):
+    return openai_resp["data"][0]["b64_json"]
+
+
+def get_image_from_openai_url(openai_resp):
+    url = openai_resp["data"][0]["url"]
+    img_content = requests.get(url).content
+    img_data = base64.b64encode(img_content)
+    return img_data
+
+
+def get_image_from_path(openai_resp):
+    img_path = openai_resp["data"][0]["url"]
+    with open(img_path, "rb") as f:
+        img_data = base64.b64encode(f.read())
+    return img_data
diff --git a/tests/unit_tests/adapter/test_openai.py b/tests/unit_tests/adapter/test_openai.py
index 6894819f..912ff350 100644
--- a/tests/unit_tests/adapter/test_openai.py
+++ b/tests/unit_tests/adapter/test_openai.py
@@ -3,11 +3,26 @@
     get_stream_message_from_openai_answer,
     get_message_from_openai_answer,
     get_text_from_openai_answer,
+    get_image_from_openai_b64,
+    get_image_from_path,
+    get_image_from_openai_url
 )
 from gptcache.adapter import openai
 from gptcache import cache
 from gptcache.processor.pre import get_prompt
+import os
+import base64
+import requests
+from io import BytesIO
+try:
+    from PIL import Image
+except ModuleNotFoundError:
+    from gptcache.utils.dependency_control import prompt_install
+    prompt_install("pillow")
+    from PIL import Image
+
+
 
 
 def test_stream_openai():
     cache.init()
@@ -113,3 +128,74 @@
     )
     answer_text = get_text_from_openai_answer(response)
     assert answer_text == expect_answer
+
+
+def test_image_create():
+    cache.init(pre_embedding_func=get_prompt)
+    prompt1 = "test url"  # prompt for the b64_json case
+    test_url = "https://raw.githubusercontent.com/zilliztech/GPTCache/dev/docs/GPTCache.png"
+    test_response = {
+        "created": 1677825464,
+        "data": [
+            {"url": test_url}
+        ]
+    }
+    prompt2 = "test base64"
+    img_bytes = base64.b64decode(get_image_from_openai_url(test_response))
+    img_file = BytesIO(img_bytes)  # convert image to file-like object
+    img = Image.open(img_file)
+    img = img.resize((256, 256))
+    buffered = BytesIO()
+    img.save(buffered, format="JPEG")
+    expected_img_data = base64.b64encode(buffered.getvalue())
+
+    ###### Return base64 ######
+    with patch("openai.Image.create") as mock_create_b64:
+        mock_create_b64.return_value = {
+            "created": 1677825464,
+            "data": [
+                {'b64_json': expected_img_data}
+            ]
+        }
+
+        response = openai.Image.create(
+            prompt=prompt1,
+            size="256x256",
+            response_format="b64_json"
+        )
+        img_returned = get_image_from_openai_b64(response)
+        assert img_returned == expected_img_data
+
+        response = openai.Image.create(
+            prompt=prompt1,
+            size="256x256",
+            response_format="b64_json"
+        )
+        img_returned = get_image_from_openai_b64(response)
+        assert img_returned == expected_img_data
+
+    ###### Return url ######
+    with patch("openai.Image.create") as mock_create_url:
+        mock_create_url.return_value = {
+            "created": 1677825464,
+            "data": [
+                {'url': test_url}
+            ]
+        }
+
+        response = openai.Image.create(
+            prompt=prompt2,
+            size="256x256",
+            response_format="url"
+        )
+        answer_url = response["data"][0]["url"]
+        assert test_url == answer_url
+
+        response = openai.Image.create(
+            prompt=prompt2,
+            size="256x256",
+            response_format="url"
+        )
+        img_returned = get_image_from_path(response)
+        assert img_returned == expected_img_data
+        os.remove(response["data"][0]["url"])