Add openai Completion adapter (#202)
Signed-off-by: shiyu22 <shiyu.chen@zilliz.com>
shiyu22 committed Apr 14, 2023
1 parent b3385e3 commit dd10f38
Showing 3 changed files with 80 additions and 1 deletion.
42 changes: 41 additions & 1 deletion gptcache/adapter/openai.py
@@ -4,7 +4,11 @@
import openai

from gptcache.adapter.adapter import adapt
from gptcache.utils.response import get_message_from_openai_answer, get_stream_message_from_openai_answer
from gptcache.utils.response import (
    get_stream_message_from_openai_answer,
    get_message_from_openai_answer,
    get_text_from_openai_answer,
)


class ChatCompletion:
@@ -84,3 +88,39 @@ def construct_stream_resp_from_cache(return_message):
"object": "chat.completion.chunk",
},
]


class Completion:
    """Openai Completion Wrapper"""

    @classmethod
    def create(cls, *args, **kwargs):
        # On a cache miss, forward the request to the OpenAI Completion API.
        def llm_handler(*llm_args, **llm_kwargs):
            return openai.Completion.create(*llm_args, **llm_kwargs)

        # On a cache hit, wrap the cached text in an OpenAI-style response.
        def cache_data_convert(cache_data):
            return construct_text_from_cache(cache_data)

        # After a real API call, write the generated text back to the cache
        # and return the raw response unchanged.
        def update_cache_callback(llm_data, update_cache_func):
            update_cache_func(get_text_from_openai_answer(llm_data))
            return llm_data

        return adapt(
            llm_handler, cache_data_convert, update_cache_callback, *args, **kwargs
        )


def construct_text_from_cache(return_text):
    # Build an OpenAI-compatible text_completion response around the cached text.
    return {
        "gptcache": True,
        "choices": [
            {
                "text": return_text,
                "finish_reason": "stop",
                "index": 0,
            }
        ],
        "created": int(time.time()),
        "usage": {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0},
        "object": "text_completion",
    }
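
For context, a minimal usage sketch of the new adapter, assuming the cache is initialized with get_prompt as the pre-embedding function (as the unit test below does) and that an OpenAI API key is configured for the first, uncached call:

from gptcache import cache
from gptcache.adapter import openai
from gptcache.processor.pre import get_prompt

# get_prompt makes the `prompt` argument the cache key.
cache.init(pre_embedding_func=get_prompt)

# The first call hits the OpenAI API and populates the cache; repeating the
# same prompt is then answered from the cache via construct_text_from_cache.
response = openai.Completion.create(
    model="text-davinci-003",
    prompt="what is your name?",
)
print(response["choices"][0]["text"])
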
4 changes: 4 additions & 0 deletions gptcache/utils/response.py
@@ -4,3 +4,7 @@ def get_message_from_openai_answer(openai_resp):

def get_stream_message_from_openai_answer(openai_data):
    return openai_data["choices"][0]["delta"].get("content", "")


def get_text_from_openai_answer(openai_resp):
    return openai_resp["choices"][0]["text"]
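
For illustration, a small sketch of what the new helper extracts; the response shape below mirrors the mock used in the unit test that follows:

# A trimmed text_completion response, as returned by openai.Completion.create.
resp = {
    "object": "text_completion",
    "choices": [{"text": "gptcache", "finish_reason": "stop", "index": 0}],
}
assert get_text_from_openai_answer(resp) == "gptcache"
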
35 changes: 35 additions & 0 deletions tests/unit_tests/adapter/test_openai.py
@@ -2,9 +2,11 @@
from gptcache.utils.response import (
    get_stream_message_from_openai_answer,
    get_message_from_openai_answer,
    get_text_from_openai_answer,
)
from gptcache.adapter import openai
from gptcache import cache
from gptcache.processor.pre import get_prompt


def test_stream_openai():
@@ -78,3 +80,36 @@ def test_stream_openai():
    )
    answer_text = get_message_from_openai_answer(response)
    assert answer_text == expect_answer, answer_text


def test_completion():
    cache.init(pre_embedding_func=get_prompt)
    question = "what is your name?"
    expect_answer = "gptcache"

    with patch("openai.Completion.create") as mock_create:
        mock_create.return_value = {
            "choices": [
                {"text": expect_answer, "finish_reason": None, "index": 0}
            ],
            "created": 1677825464,
            "id": "cmpl-6ptKyqKOGXZT6iQnqiXAH8adNLUzD",
            "model": "text-davinci-003",
            "object": "text_completion",
        }

        response = openai.Completion.create(
            model="text-davinci-003",
            prompt=question,
        )
        answer_text = get_text_from_openai_answer(response)
        assert answer_text == expect_answer

    # The same prompt again should be answered from the cache rather than the API.
    response = openai.Completion.create(
        model="text-davinci-003",
        prompt=question,
    )
    answer_text = get_text_from_openai_answer(response)
    assert answer_text == expect_answer
