Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add openai.Image.create in adapter #208

Merged
merged 1 commit
Apr 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 75 additions & 0 deletions examples/openai_examples/create_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import time
import base64
from io import BytesIO
from PIL import Image

from gptcache.adapter import openai
from gptcache.processor.pre import get_prompt
from gptcache import cache

from gptcache.embedding import Onnx
from gptcache.similarity_evaluation.distance import SearchDistanceEvaluation
from gptcache.manager import get_data_manager, CacheBase, VectorBase


def _create_and_check(prompt: str, size: str) -> None:
    """Generate an image through the GPTCache openai adapter and verify its size.

    Requests a base64-encoded image, times the call (a cache hit returns far
    faster than a real API call), decodes the payload with PIL, and asserts
    the decoded image matches the requested ``size`` ("WIDTHxHEIGHT").
    """
    start = time.time()
    response = openai.Image.create(
        prompt=prompt,
        size=size,
        response_format='b64_json',
    )
    print('Time elapsed:', time.time() - start)

    img_b64 = response['data'][0]['b64_json']
    img_bytes = base64.b64decode(img_b64)
    img = Image.open(BytesIO(img_bytes))  # decode to PIL for the size check
    expected_size = tuple(int(x) for x in size.split('x'))
    # The original assert message lacked the f-prefix, so it printed literal
    # "{size1}" placeholders; the f-string below interpolates correctly.
    assert img.size == expected_size, \
        f'Expected to generate an image of size {size} but got {img.size}.'


onnx = Onnx()
cache_base = CacheBase('sqlite')
vector_base = VectorBase('milvus', host='localhost', port='19530',
                         collection_name='gptcache_image', dimension=onnx.dimension)
data_manager = get_data_manager(cache_base, vector_base)

cache.init(
    pre_embedding_func=get_prompt,
    embedding_func=onnx.to_embeddings,
    data_manager=data_manager,
    similarity_evaluation=SearchDistanceEvaluation(),
)
cache.set_openai_key()

# First prompt populates the cache; the second, similarly worded prompt may be
# answered from the cache via similarity search.
_create_and_check('a cat sitting besides a dog', '256x256')
_create_and_check('a dog sitting besides a cat', '512x512')
72 changes: 72 additions & 0 deletions gptcache/adapter/openai.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,27 @@
import time
from typing import Iterator

import base64
from io import BytesIO
import os

import openai


from gptcache import CacheError
from gptcache.adapter.adapter import adapt
from gptcache.utils.response import (
get_stream_message_from_openai_answer,
get_message_from_openai_answer,
get_text_from_openai_answer,
get_image_from_openai_b64,
get_image_from_openai_url
)
from gptcache.utils import import_pillow

import_pillow()

from PIL import Image as PILImage # pylint: disable=C0413


class ChatCompletion(openai.ChatCompletion):
Expand Down Expand Up @@ -53,6 +65,36 @@ def cache_data_convert(cache_data):
**kwargs
)

class Image(openai.Image):
    """GPTCache wrapper around ``openai.Image`` that caches ``create`` results.

    Cache hits are rebuilt into an OpenAI-shaped response by
    ``construct_image_create_resp_from_cache``; misses call the real API and
    store the generated image (normalized to base64) via the update callback.
    """

    @classmethod
    def create(cls, *args, **kwargs):
        """Cached drop-in replacement for ``openai.Image.create``.

        Raises:
            CacheError: wrapping any error raised by the underlying API call.
        """
        # OpenAI defaults to a URL response and a 1024x1024 image when these
        # kwargs are omitted; mirror that instead of raising KeyError.
        response_format = kwargs.get("response_format", "url")
        size = kwargs.get("size", "1024x1024")

        def llm_handler(*llm_args, **llm_kwargs):
            try:
                return openai.Image.create(*llm_args, **llm_kwargs)
            except Exception as e:  # pylint: disable=broad-except
                raise CacheError("openai error") from e

        def cache_data_convert(cache_data):
            # Rebuild an OpenAI-style response dict from the cached base64 data.
            return construct_image_create_resp_from_cache(
                image_data=cache_data,
                response_format=response_format,
                size=size,
            )

        def update_cache_callback(llm_data, update_cache_func):
            # Normalize the fresh response to base64 before storing it, then
            # always hand the raw response back to the caller (the original
            # implicitly returned None for unrecognized formats).
            if response_format == "b64_json":
                update_cache_func(get_image_from_openai_b64(llm_data))
            elif response_format == "url":
                update_cache_func(get_image_from_openai_url(llm_data))
            return llm_data

        return adapt(
            llm_handler, cache_data_convert, update_cache_callback, *args, **kwargs
        )


def construct_resp_from_cache(return_message):
return {
Expand Down Expand Up @@ -141,3 +183,33 @@ def construct_text_from_cache(return_text):
"usage": {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0},
"object": "text_completion",
}


def construct_image_create_resp_from_cache(image_data, response_format, size):
    """Build an ``openai.Image.create``-shaped response from cached image data.

    Args:
        image_data: base64-encoded image bytes previously stored in the cache.
        response_format: ``"url"`` (write a local JPEG file and return its
            absolute path) or ``"b64_json"`` (return the payload inline).
        size: target size as ``"WIDTHxHEIGHT"``; the cached image is resized
            when it does not already match.

    Raises:
        AttributeError: if ``response_format`` is not one of the two values
            above (kept as AttributeError for backward compatibility).
    """
    img_bytes = base64.b64decode(image_data)
    img_file = BytesIO(img_bytes)  # convert image to file-like object
    img = PILImage.open(img_file)
    new_size = tuple(int(a) for a in size.split("x"))
    if new_size != img.size:
        if img.mode != "RGB":
            # JPEG cannot store an alpha channel; PNG-sourced RGBA images
            # would make save() raise OSError without this conversion.
            img = img.convert("RGB")
        img = img.resize(new_size)
        buffered = BytesIO()
        img.save(buffered, format="JPEG")
    else:
        # Size already matches: reuse the original bytes untouched.
        buffered = img_file

    if response_format == "url":
        # Persist to a timestamp-named file and hand back its absolute path.
        target_url = os.path.abspath(str(int(time.time())) + ".jpeg")
        with open(target_url, "wb") as f:
            f.write(buffered.getvalue())
        image_data = target_url
    elif response_format == "b64_json":
        image_data = base64.b64encode(buffered.getvalue())
    else:
        raise AttributeError(f"Invalid response_format: {response_format} is not one of ['url', 'b64_json']")

    return {
        "created": int(time.time()),
        "data": [
            {response_format: image_data}
        ]
    }
5 changes: 5 additions & 0 deletions gptcache/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"import_sql_client",
"import_pydantic",
"import_langchain",
"import_pillow"
]

import importlib.util
Expand Down Expand Up @@ -120,3 +121,7 @@ def import_pydantic():

def import_langchain():
    """Ensure the ``langchain`` package is importable, prompting an install if absent."""
    _check_library("langchain")


def import_pillow():
    """Ensure the ``pillow`` package is importable, prompting an install if absent."""
    _check_library("pillow")
21 changes: 21 additions & 0 deletions gptcache/utils/response.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import base64
import requests

def get_message_from_openai_answer(openai_resp):
    """Return the assistant message text from a ChatCompletion response dict."""
    first_choice = openai_resp["choices"][0]
    return first_choice["message"]["content"]

Expand All @@ -8,3 +11,21 @@ def get_stream_message_from_openai_answer(openai_data):

def get_text_from_openai_answer(openai_resp):
    """Return the generated text from a Completion response dict."""
    first_choice = openai_resp["choices"][0]
    return first_choice["text"]


def get_image_from_openai_b64(openai_resp):
    """Return the base64-encoded image payload from an Image response dict."""
    first_item = openai_resp["data"][0]
    return first_item["b64_json"]


def get_image_from_openai_url(openai_resp):
    """Download the image referenced by an Image response and return it base64-encoded.

    Raises:
        requests.HTTPError: if the download fails — the original silently
            base64-encoded whatever body came back (e.g. an HTML error page),
            which would then be cached as image data.
    """
    url = openai_resp["data"][0]["url"]
    resp = requests.get(url, timeout=30)  # bound the request; no timeout hangs forever
    resp.raise_for_status()
    img_data = base64.b64encode(resp.content)
    return img_data


def get_image_from_path(openai_resp):
    """Read the local image file referenced by the response and return it base64-encoded."""
    img_path = openai_resp["data"][0]["url"]
    with open(img_path, "rb") as img_file:
        raw_bytes = img_file.read()
    return base64.b64encode(raw_bytes)
86 changes: 86 additions & 0 deletions tests/unit_tests/adapter/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,26 @@
get_stream_message_from_openai_answer,
get_message_from_openai_answer,
get_text_from_openai_answer,
get_image_from_openai_b64,
get_image_from_path,
get_image_from_openai_url
)
from gptcache.adapter import openai
from gptcache import cache
from gptcache.processor.pre import get_prompt

import os
import base64
import requests
from io import BytesIO
try:
from PIL import Image
except ModuleNotFoundError:
from gptcache.utils.dependency_control import prompt_install
prompt_install("pillow")
from PIL import Image



def test_stream_openai():
cache.init()
Expand Down Expand Up @@ -113,3 +128,74 @@ def test_completion():
)
answer_text = get_text_from_openai_answer(response)
assert answer_text == expect_answer


def test_image_create():
    """Exercise the openai.Image.create adapter for both response formats.

    Each format is requested twice with the same prompt: the first call goes
    through the mocked API and populates the cache; the second should be
    served from the cache with an equivalent payload.

    NOTE(review): building ``expected_img_data`` fetches ``test_url`` over the
    network, so this test requires internet access.
    """
    cache.init(pre_embedding_func=get_prompt)
    # The original had the prompt strings swapped relative to the sections
    # they drive ("test url" keyed the b64 case and vice versa); they now
    # match. Each prompt is still a distinct cache key, so behavior is the same.
    prompt_b64 = "test base64"
    prompt_url = "test url"
    test_url = "https://raw.githubusercontent.com/zilliztech/GPTCache/dev/docs/GPTCache.png"
    test_response = {
        "created": 1677825464,
        "data": [
            {"url": test_url}
        ]
    }

    # Build the expected 256x256 JPEG payload from the reference image.
    img_bytes = base64.b64decode(get_image_from_openai_url(test_response))
    img_file = BytesIO(img_bytes)  # convert image to file-like object
    img = Image.open(img_file)
    img = img.resize((256, 256))
    buffered = BytesIO()
    img.save(buffered, format="JPEG")
    expected_img_data = base64.b64encode(buffered.getvalue())

    ###### Return base64 ######
    with patch("openai.Image.create") as mock_create_b64:
        mock_create_b64.return_value = {
            "created": 1677825464,
            "data": [
                {'b64_json': expected_img_data}
            ]
        }

        # First iteration populates the cache; the second must hit it and
        # return an identical payload.
        for _ in range(2):
            response = openai.Image.create(
                prompt=prompt_b64,
                size="256x256",
                response_format="b64_json"
            )
            assert get_image_from_openai_b64(response) == expected_img_data

    ###### Return url ######
    with patch("openai.Image.create") as mock_create_url:
        mock_create_url.return_value = {
            "created": 1677825464,
            "data": [
                {'url': test_url}
            ]
        }

        # First call: the mocked API returns the remote URL directly.
        response = openai.Image.create(
            prompt=prompt_url,
            size="256x256",
            response_format="url"
        )
        assert test_url == response["data"][0]["url"]

        # Second call: served from cache as a locally written JPEG file.
        response = openai.Image.create(
            prompt=prompt_url,
            size="256x256",
            response_format="url"
        )
        assert get_image_from_path(response) == expected_img_data
        os.remove(response["data"][0]["url"])