Skip to content

Commit

Permalink
Add openai.Image.create in adapter
Browse files Browse the repository at this point in the history
Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
  • Loading branch information
jaelgu committed Apr 14, 2023
1 parent 3cc350e commit 0d72c80
Show file tree
Hide file tree
Showing 5 changed files with 259 additions and 0 deletions.
75 changes: 75 additions & 0 deletions examples/openai_examples/create_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import time
import base64
from io import BytesIO
from PIL import Image

from gptcache.adapter import openai
from gptcache.processor.pre import get_prompt
from gptcache import cache

from gptcache.embedding import Onnx
from gptcache.similarity_evaluation.distance import SearchDistanceEvaluation
from gptcache.manager import get_data_manager, CacheBase, VectorBase


onnx = Onnx()
cache_base = CacheBase('sqlite')
vector_base = VectorBase('milvus', host='localhost', port='19530', collection_name='gptcache_image', dimension=onnx.dimension)
data_manager = get_data_manager(cache_base, vector_base)

cache.init(
pre_embedding_func=get_prompt,
embedding_func=onnx.to_embeddings,
data_manager=data_manager,
similarity_evaluation=SearchDistanceEvaluation(),
)
cache.set_openai_key()

##################### Create an image with prompt1 ###################
prompt1 = 'a cat sitting besides a dog'
size1 = '256x256'

start = time.time()
response1 = openai.Image.create(
prompt=prompt1,
size=size1,
# response_format='b64_json'
response_format='b64_json'
)
end = time.time()
print('Time elapsed:', end - start)

# img = Image.open(response['data'][0]['url'])
# print(img.size)

img_b64_1 = response1['data'][0]['b64_json']
img_bytes_1 = base64.b64decode((img_b64_1))
img_file_1 = BytesIO(img_bytes_1) # convert image to file-like object
img_1 = Image.open(img_file_1) # convert image to PIL
assert img_1.size == tuple([int(x) for x in size1.split('x')]), \
'Expected to generate an image of size {size1} but got {img_1.size}.'


##################### Create an image with prompt2 #####################
prompt2 = 'a dog sitting besides a cat'
size2 = '512x512'

start = time.time()
response2 = openai.Image.create(
prompt=prompt2,
size=size2,
# response_format='b64_json'
response_format='b64_json'
)
end = time.time()
print('Time elapsed:', end - start)

# img = Image.open(response['data'][0]['url'])
# print(img.size)

img_b64_2 = response2['data'][0]['b64_json']
img_bytes_2 = base64.b64decode((img_b64_2))
img_file_2 = BytesIO(img_bytes_2) # convert image to file-like object
img_2 = Image.open(img_file_2) # convert image to PIL
assert img_2.size == tuple([int(x) for x in size2.split('x')]), \
f'Expected to generate an image of size {size2} but got {img_2.size}.'
72 changes: 72 additions & 0 deletions gptcache/adapter/openai.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,27 @@
import time
from typing import Iterator

import base64
from io import BytesIO
import os

import openai


from gptcache import CacheError
from gptcache.adapter.adapter import adapt
from gptcache.utils.response import (
get_stream_message_from_openai_answer,
get_message_from_openai_answer,
get_text_from_openai_answer,
get_image_from_openai_b64,
get_image_from_openai_url
)
from gptcache.utils import import_pillow

import_pillow()

from PIL import Image as PILImage # pylint: disable=C0413


class ChatCompletion(openai.ChatCompletion):
Expand Down Expand Up @@ -53,6 +65,36 @@ def cache_data_convert(cache_data):
**kwargs
)

class Image(openai.Image):
"""Openai Image Wrapper"""

@classmethod
def create(cls, *args, **kwargs):
def llm_handler(*llm_args, **llm_kwargs):
try:
return openai.Image.create(*llm_args, **llm_kwargs)
except Exception as e:
raise CacheError("openai error") from e

def cache_data_convert(cache_data):
return construct_image_create_resp_from_cache(
image_data=cache_data,
response_format=kwargs["response_format"],
size=kwargs["size"]
)

def update_cache_callback(llm_data, update_cache_func):
if kwargs["response_format"] == "b64_json":
update_cache_func(get_image_from_openai_b64(llm_data))
return llm_data
elif kwargs["response_format"] == "url":
update_cache_func(get_image_from_openai_url(llm_data))
return llm_data

return adapt(
llm_handler, cache_data_convert, update_cache_callback, *args, **kwargs
)


def construct_resp_from_cache(return_message):
return {
Expand Down Expand Up @@ -141,3 +183,33 @@ def construct_text_from_cache(return_text):
"usage": {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0},
"object": "text_completion",
}


def construct_image_create_resp_from_cache(image_data, response_format, size):
img_bytes = base64.b64decode((image_data))
img_file = BytesIO(img_bytes) # convert image to file-like object
img = PILImage.open(img_file)
new_size = tuple(int(a) for a in size.split("x"))
if new_size != img.size:
img = img.resize(new_size)
buffered = BytesIO()
img.save(buffered, format="JPEG")
else:
buffered = img_file

if response_format == "url":
target_url = os.path.abspath(str(int(time.time())) + ".jpeg")
with open(target_url, "wb") as f:
f.write(buffered.getvalue())
image_data = target_url
elif response_format == "b64_json":
image_data = base64.b64encode(buffered.getvalue())
else:
raise AttributeError(f"Invalid response_format: {response_format} is not one of ['url', 'b64_json']")

return {
"created": int(time.time()),
"data": [
{response_format: image_data}
]
}
5 changes: 5 additions & 0 deletions gptcache/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"import_sql_client",
"import_pydantic",
"import_langchain",
"import_pillow"
]

import importlib.util
Expand Down Expand Up @@ -120,3 +121,7 @@ def import_pydantic():

def import_langchain():
_check_library("langchain")


def import_pillow():
_check_library("pillow")
21 changes: 21 additions & 0 deletions gptcache/utils/response.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import base64
import requests

def get_message_from_openai_answer(openai_resp):
return openai_resp["choices"][0]["message"]["content"]

Expand All @@ -8,3 +11,21 @@ def get_stream_message_from_openai_answer(openai_data):

def get_text_from_openai_answer(openai_resp):
return openai_resp["choices"][0]["text"]


def get_image_from_openai_b64(openai_resp):
return openai_resp["data"][0]["b64_json"]


def get_image_from_openai_url(openai_resp):
url = openai_resp["data"][0]["url"]
img_content = requests.get(url).content
img_data = base64.b64encode(img_content)
return img_data


def get_image_from_path(openai_resp):
img_path = openai_resp["data"][0]["url"]
with open(img_path, "rb") as f:
img_data = base64.b64encode(f.read())
return img_data
86 changes: 86 additions & 0 deletions tests/unit_tests/adapter/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,26 @@
get_stream_message_from_openai_answer,
get_message_from_openai_answer,
get_text_from_openai_answer,
get_image_from_openai_b64,
get_image_from_path,
get_image_from_openai_url
)
from gptcache.adapter import openai
from gptcache import cache
from gptcache.processor.pre import get_prompt

import os
import base64
import requests
from io import BytesIO
try:
from PIL import Image
except ModuleNotFoundError:
from gptcache.utils.dependency_control import prompt_install
prompt_install("pillow")
from PIL import Image



def test_stream_openai():
cache.init()
Expand Down Expand Up @@ -113,3 +128,74 @@ def test_completion():
)
answer_text = get_text_from_openai_answer(response)
assert answer_text == expect_answer


def test_image_create():
cache.init(pre_embedding_func=get_prompt)
prompt1 = "test url"# bytes
test_url = "https://raw.githubusercontent.com/zilliztech/GPTCache/dev/docs/GPTCache.png"
test_response = {
"created": 1677825464,
"data": [
{"url": test_url}
]
}
prompt2 = "test base64"
img_bytes = base64.b64decode(get_image_from_openai_url(test_response))
img_file = BytesIO(img_bytes) # convert image to file-like object
img = Image.open(img_file)
img = img.resize((256, 256))
buffered = BytesIO()
img.save(buffered, format="JPEG")
expected_img_data = base64.b64encode(buffered.getvalue())

###### Return base64 ######
with patch("openai.Image.create") as mock_create_b64:
mock_create_b64.return_value = {
"created": 1677825464,
"data": [
{'b64_json': expected_img_data}
]
}

response = openai.Image.create(
prompt=prompt1,
size="256x256",
response_format="b64_json"
)
img_returned = get_image_from_openai_b64(response)
assert img_returned == expected_img_data

response = openai.Image.create(
prompt=prompt1,
size="256x256",
response_format="b64_json"
)
img_returned = get_image_from_openai_b64(response)
assert img_returned == expected_img_data

###### Return url ######
with patch("openai.Image.create") as mock_create_url:
mock_create_url.return_value = {
"created": 1677825464,
"data": [
{'url': test_url}
]
}

response = openai.Image.create(
prompt=prompt2,
size="256x256",
response_format="url"
)
answer_url = response["data"][0]["url"]
assert test_url == answer_url

response = openai.Image.create(
prompt=prompt2,
size="256x256",
response_format="url"
)
img_returned = get_image_from_path(response)
assert img_returned == expected_img_data
os.remove(response["data"][0]["url"])

0 comments on commit 0d72c80

Please sign in to comment.