Skip to content

Commit

Permalink
Add object storage (#213)
Browse files Browse the repository at this point in the history
Signed-off-by: junjie.jiang <junjie.jiang@zilliz.com>
  • Loading branch information
junjiejiangjjj committed Apr 16, 2023
1 parent 39faf3a commit 211949a
Show file tree
Hide file tree
Showing 19 changed files with 412 additions and 51 deletions.
9 changes: 4 additions & 5 deletions gptcache/adapter/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,22 +49,21 @@ def adapt(llm_handler, cache_data_convert, update_cache_callback, *args, **kwarg
)
if ret is None:
continue
cache_question, cache_answer, cache_embedding = ret
rank = chat_cache.similarity_evaluation.evaluation(
{
"question": pre_embedding_data,
"embedding": embedding_data,
},
{
"question": cache_question,
"answer": cache_answer,
"question": ret.question,
"answer": ret.answers[0].answer,
"search_result": cache_data,
"embedding": cache_embedding
"embedding": ret.embedding_data
},
extra_param=context.get("evaluation_func", None),
)
if rank_threshold <= rank:
cache_answers.append((rank, cache_answer))
cache_answers.append((rank, ret.answers[0].answer))
chat_cache.data_manager.update_access_time(cache_data)
cache_answers = sorted(cache_answers, key=lambda x: x[0], reverse=True)
if len(cache_answers) != 0:
Expand Down
3 changes: 2 additions & 1 deletion gptcache/adapter/langchain_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from gptcache.adapter.adapter import adapt
from gptcache.utils import import_pydantic, import_langchain
from gptcache.manager.scalar_data.base import Answer, AnswerType

import_pydantic()
import_langchain()
Expand Down Expand Up @@ -107,7 +108,7 @@ def cache_data_convert(cache_data):


def update_cache_callback(llm_data, update_cache_func):
update_cache_func(llm_data)
update_cache_func(Answer(llm_data, AnswerType.STR))
return llm_data


Expand Down
9 changes: 5 additions & 4 deletions gptcache/adapter/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from gptcache import CacheError
from gptcache.adapter.adapter import adapt
from gptcache.manager.scalar_data.base import Answer, AnswerType
from gptcache.utils.response import (
get_stream_message_from_openai_answer,
get_message_from_openai_answer,
Expand Down Expand Up @@ -37,7 +38,7 @@ def llm_handler(cls, *llm_args, **llm_kwargs):
@staticmethod
def update_cache_callback(llm_data, update_cache_func):
if not isinstance(llm_data, Iterator):
update_cache_func(get_message_from_openai_answer(llm_data))
update_cache_func(Answer(get_message_from_openai_answer(llm_data)), AnswerType.STR)
return llm_data
else:

Expand All @@ -46,7 +47,7 @@ def hook_openai_data(it):
for item in it:
total_answer += get_stream_message_from_openai_answer(item)
yield item
update_cache_func(total_answer)
update_cache_func(Answer(total_answer, AnswerType.STR))

return hook_openai_data(llm_data)

Expand Down Expand Up @@ -85,10 +86,10 @@ def cache_data_convert(cache_data):

def update_cache_callback(llm_data, update_cache_func):
if kwargs["response_format"] == "b64_json":
update_cache_func(get_image_from_openai_b64(llm_data))
update_cache_func(Answer(get_image_from_openai_b64(llm_data), AnswerType.IMAGE_BASE64))
return llm_data
elif kwargs["response_format"] == "url":
update_cache_func(get_image_from_openai_url(llm_data))
update_cache_func(Answer(get_image_from_openai_url(llm_data), AnswerType.IMAGE_URL))
return llm_data

return adapt(
Expand Down
1 change: 1 addition & 0 deletions gptcache/manager/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from gptcache.manager.scalar_data import CacheBase
from gptcache.manager.vector_data import VectorBase
from gptcache.manager.object_data import ObjectBase
from gptcache.manager.factory import get_data_manager
50 changes: 33 additions & 17 deletions gptcache/manager/data_manager.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from abc import abstractmethod, ABCMeta
import pickle
from typing import List, Any
from typing import List, Any, Optional, Union

import cachetools
import numpy as np

from gptcache.utils.error import CacheError, ParamError
from gptcache.manager.scalar_data.base import CacheStorage, CacheData
from gptcache.manager.scalar_data.base import CacheStorage, CacheData, AnswerType, Answer
from gptcache.manager.vector_data.base import VectorBase, VectorData
from gptcache.manager.object_data.base import ObjectBase
from gptcache.manager.eviction import EvictionManager


Expand All @@ -24,9 +25,8 @@ def import_data(
):
pass

# should return the tuple, (question, answer, embedding)
@abstractmethod
def get_scalar_data(self, res_data, **kwargs):
def get_scalar_data(self, res_data, **kwargs) -> CacheData:
pass

def update_access_time(self, res_data, **kwargs):
Expand Down Expand Up @@ -89,8 +89,8 @@ def import_data(
for i, embedding_data in enumerate(embedding_datas):
self.data[embedding_data] = (questions[i], answers[i], embedding_datas[i])

def get_scalar_data(self, res_data, **kwargs):
return res_data
def get_scalar_data(self, res_data, **kwargs) -> CacheData:
return CacheData(question=res_data[0], answers=res_data[1])

def search(self, embedding_data, **kwargs):
try:
Expand Down Expand Up @@ -127,15 +127,13 @@ class SSDataManager(DataManager):
:type eviction: str
"""

s: CacheStorage
v: VectorBase

def __init__(self, s: CacheStorage, v: VectorBase, max_size, clean_size, eviction="LRU"):
def __init__(self, s: CacheStorage, v: VectorBase, o: Optional[ObjectBase], max_size, clean_size, eviction="LRU"):
self.max_size = max_size
self.cur_size = 0
self.clean_size = clean_size
self.s = s
self.v = v
self.o = o
self.eviction = EvictionManager(self.s, self.v, eviction)
self.cur_size = self.s.count()

Expand All @@ -145,13 +143,13 @@ def _clear(self):
self.eviction.delete()
self.cur_size = self.s.count()

def save(self, question, answer, embedding_data, **kwargs):
def save(self, question, answer, embedding_data , **kwargs):
"""Save the data and vectors to cache and vector storage.
:param question: question data.
:type question: str
:param answer: answer data.
:type answer: str
:type answer: str, Answer or (Any, AnswerType)
:param embedding_data: vector data.
:type embedding_data: np.ndarray
Expand All @@ -167,11 +165,21 @@ def save(self, question, answer, embedding_data, **kwargs):

if self.cur_size >= self.max_size:
self._clear()

self.import_data([question], [answer], [embedding_data])

def _process_answer_data(self, answers: Union[Answer, List[Answer]]) -> List[Answer]:
    """Offload non-string answer payloads to object storage.

    Accepts a single ``Answer`` or a list of them and always returns a list.
    Any answer whose ``answer_type`` is not ``AnswerType.STR`` has its payload
    written to ``self.o`` (the object storage) and is replaced by a new
    ``Answer`` carrying the storage key returned by ``self.o.put``; STR
    answers pass through unchanged.

    NOTE(review): callers must ensure ``self.o`` is not None before calling
    this (the visible call site guards with ``if self.o is not None``).
    """
    if isinstance(answers, Answer):
        answers = [answers]
    new_ans = []
    for ans in answers:
        if ans.answer_type != AnswerType.STR:
            # Store the raw payload; keep only the object-storage key inline.
            new_ans.append(Answer(self.o.put(ans.answer), ans.answer_type))
        else:
            new_ans.append(ans)
    return new_ans

def import_data(
self, questions: List[Any], answers: List[Any], embedding_datas: List[Any]
self, questions: List[Any], answers: List[Answer], embedding_datas: List[Any]
):
if len(questions) != len(answers) or len(questions) != len(embedding_datas):
raise ParamError("Make sure that all parameters have the same length")
Expand All @@ -180,10 +188,14 @@ def import_data(
normalize(embedding_data) for embedding_data in embedding_datas
]
for i, embedding_data in enumerate(embedding_datas):
if self.o is not None:
ans = self._process_answer_data(answers[i])
else:
ans = answers[i]
cache_datas.append(
CacheData(
question=questions[i],
answer=answers[i],
answers=ans,
embedding_data=embedding_data.astype("float32"),
)
)
Expand All @@ -196,8 +208,12 @@ def import_data(
)
self.cur_size += len(questions)

def get_scalar_data(self, res_data, **kwargs):
return self.s.get_data_by_id(res_data[1])
def get_scalar_data(self, res_data, **kwargs) -> CacheData:
    """Load the full cache entry for a vector-search hit.

    ``res_data[1]`` is used as the scalar-storage id of the matched entry.
    Answers whose type is not ``AnswerType.STR`` hold object-storage keys;
    those are resolved back into their payloads via ``self.o.get`` in place.

    NOTE(review): assumes ``self.o`` is set whenever non-STR answers were
    stored — confirm against the import path.
    """
    cache_data = self.s.get_data_by_id(res_data[1])
    for ans in cache_data.answers:
        if ans.answer_type != AnswerType.STR:
            # Replace the stored key with the actual object payload.
            ans.answer = self.o.get(ans.answer)
    return cache_data

def update_access_time(self, res_data, **kwargs):
    """Forward the access-time update for the matched entry to scalar storage.

    ``res_data[1]`` is the scalar-storage id of the entry, consistent with
    ``get_scalar_data``.
    """
    return self.s.update_access_time(res_data[1])
Expand Down
11 changes: 8 additions & 3 deletions gptcache/manager/factory.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from typing import Union, Callable
from gptcache.manager.data_manager import SSDataManager, MapDataManager
from gptcache.manager import CacheBase, VectorBase
from gptcache.manager import CacheBase, VectorBase, ObjectBase


def get_data_manager(
cache_base: Union[CacheBase, str] = None,
vector_base: Union[VectorBase, str] = None,
object_base: Union[ObjectBase, str] = None,
max_size: int = 1000,
clean_size: int = None,
eviction: str = "LRU",
Expand All @@ -21,6 +22,8 @@ def get_data_manager(
:param vector_base: a VectorBase object, or the name of the vector storage, it is support 'milvus', 'faiss' and
'chromadb' now.
:type vector_base: :class:`VectorBase` or str
:param object_base: a object storage, supports local path and s3.
:type object_base: :class:`ObjectBase` or str
:param max_size: the max size for the cache, defaults to 1000.
:type max_size: int
:param clean_size: the size to clean up, defaults to `max_size * 0.2`.
Expand Down Expand Up @@ -48,6 +51,8 @@ def get_data_manager(
if isinstance(cache_base, str):
cache_base = CacheBase(name=cache_base)
if isinstance(vector_base, str):
vector_base = VectorBase(name=cache_base)
vector_base = VectorBase(name=vector_base)
if isinstance(object_base, str):
object_base = ObjectBase(name=object_base)
assert cache_base and vector_base
return SSDataManager(cache_base, vector_base, max_size, clean_size, eviction)
return SSDataManager(cache_base, vector_base, object_base, max_size, clean_size, eviction)
37 changes: 37 additions & 0 deletions gptcache/manager/object_data/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
__all__ = ["ObjectBase"]

from gptcache.utils.lazy_import import LazyImport

object_manager = LazyImport(
"object_manager", globals(), "gptcache.manager.object_data.manager"
)


def ObjectBase(name: str, **kwargs):
    """Generate a specific ObjectStorage from the configuration.

    ``ObjectBase`` (with ``name``) selects the backing store, e.g.
    LocalObjectStorage or S3 object storage.

    :param name: the name of the object storage; 'local' and 's3' are supported.
    :type name: str
    :param path: the cache root of the LocalObjectStorage.
    :type path: str
    :param bucket: the bucket of s3.
    :type bucket: str
    :param path_prefix: s3 object prefix.
    :type path_prefix: str
    :param access_key: the access_key of s3.
    :type access_key: str
    :param secret_key: the secret_key of s3.
    :type secret_key: str
    :return: ObjectStorage.

    Example:
        .. code-block:: python

            from gptcache.manager import ObjectBase

            obj_storage = ObjectBase('local', path='./')
    """
    return object_manager.ObjectBase.get(name, **kwargs)
24 changes: 24 additions & 0 deletions gptcache/manager/object_data/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from abc import ABC, abstractmethod
from typing import Any, List


class ObjectBase(ABC):
    """Interface every object-storage backend must implement.

    Concrete backends (e.g. local filesystem, S3) persist opaque payloads
    and hand back string keys that are used for later retrieval.
    """

    @abstractmethod
    def put(self, obj: Any) -> str:
        """Persist *obj* and return the key identifying it."""

    @abstractmethod
    def get(self, obj: str) -> Any:
        """Return the payload stored under key *obj*."""

    @abstractmethod
    def get_access_link(self, obj: str) -> str:
        """Return a link through which the object *obj* can be accessed."""

    @abstractmethod
    def delete(self, to_delete: List[str]):
        """Remove every object named in *to_delete*."""
42 changes: 42 additions & 0 deletions gptcache/manager/object_data/local_storage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from typing import Any, List
import os
import logging
import uuid
from pathlib import Path
from gptcache.manager.object_data.base import ObjectBase

logger = logging.getLogger()


class LocalObjectStorage(ObjectBase):
    """Object storage backed by the local filesystem.

    Each stored object becomes a file named with a random UUID under
    ``local_root``; the file's absolute path doubles as the object key.
    """

    def __init__(self, local_root: str):
        self._local_root = Path(local_root)
        # parents=True so a nested, not-yet-existing root works too
        # (plain exist_ok=True alone raises FileNotFoundError in that case).
        self._local_root.mkdir(parents=True, exist_ok=True)

    def put(self, obj: Any) -> str:
        """Write *obj* (bytes-like) to a new file and return its absolute path."""
        f_path = self._local_root / str(uuid.uuid4())
        with open(f_path, "wb") as f:
            # NOTE(review): file is opened in binary mode, so obj must be
            # bytes-like; a str payload would raise TypeError — confirm callers.
            f.write(obj)
        return str(f_path.absolute())

    def get(self, obj: str) -> Any:
        """Return the bytes stored at path *obj*, or None if unreadable."""
        try:
            with open(obj, "rb") as f:
                return f.read()
        except Exception:  # pylint: disable=broad-except
            return None

    def get_access_link(self, obj: str, _: int = 3600):
        """Local files are directly reachable; the expiry argument is ignored."""
        return obj

    def delete(self, to_delete: List[str]):
        """Best-effort removal of every path in *to_delete*.

        Missing or non-removable files are logged, not raised.
        """
        assert isinstance(to_delete, list)
        for obj in to_delete:
            try:
                os.remove(obj)
            except Exception:  # pylint: disable=broad-except
                logger.warning("Can not find obj: %s", obj)
26 changes: 26 additions & 0 deletions gptcache/manager/object_data/manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from gptcache.utils.error import NotFoundStoreError


class ObjectBase:
    """Factory that resolves an object-storage backend by name.

    Not meant to be instantiated; use the static ``ObjectBase.get(name)``.
    """

    def __init__(self):
        # Fixed copy-paste error: the message named "CacheBase" and dropped
        # the "not", inverting its meaning.
        raise EnvironmentError(
            "ObjectBase is not designed to be instantiated, please use the `ObjectBase.get(name)`."
        )

    @staticmethod
    def get(name, **kwargs):
        """Return the object storage identified by *name*.

        :param name: 'local' (LocalObjectStorage) or 's3' (S3Storage).
        :raises NotFoundStoreError: if *name* matches no known backend.
        """
        if name == "local":
            from gptcache.manager.object_data.local_storage import LocalObjectStorage  # pylint: disable=import-outside-toplevel
            return LocalObjectStorage(kwargs.get("path", "./local_obj"))
        if name == "s3":
            from gptcache.manager.object_data.s3_storage import S3Storage  # pylint: disable=import-outside-toplevel
            return S3Storage(kwargs.get("path_prefix"), kwargs.get("bucket"),
                             kwargs.get("access_key"), kwargs.get("secret_key"),
                             kwargs.get("endpoint"))
        raise NotFoundStoreError("object store", name)
Loading

0 comments on commit 211949a

Please sign in to comment.