Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support the pinecone vector store #485

Open
wants to merge 3 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions gptcache/adapter/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def adapt(llm_handler, cache_data_convert, update_cache_callback, *args, **kwarg
chat_cache = kwargs.pop("cache_obj", cache)
session = kwargs.pop("session", None)
require_object_store = kwargs.pop("require_object_store", False)
# metadata = kwargs.pop("metadata", {})
if require_object_store:
assert chat_cache.data_manager.o, "Object store is required for adapter."
if not chat_cache.has_init:
Expand Down Expand Up @@ -91,6 +92,7 @@ def adapt(llm_handler, cache_data_convert, update_cache_callback, *args, **kwarg
top_k=kwargs.pop("top_k", 5)
if (user_temperature and not user_top_k)
else kwargs.pop("top_k", -1),
**kwargs,
)
if search_data_list is None:
search_data_list = []
Expand Down Expand Up @@ -245,7 +247,7 @@ def post_process():
if cache_enable:
try:

def update_cache_func(handled_llm_data, question=None):
def update_cache_func(handled_llm_data, question=None, **kwargs):
if question is None:
question = pre_store_data
else:
Expand All @@ -260,14 +262,14 @@ def update_cache_func(handled_llm_data, question=None):
embedding_data,
extra_param=context.get("save_func", None),
session=session,
**kwargs
)
if (
chat_cache.report.op_save.count > 0
and chat_cache.report.op_save.count % chat_cache.config.auto_flush
== 0
):
chat_cache.flush()

llm_data = update_cache_callback(
llm_data, update_cache_func, *args, **kwargs
)
Expand Down Expand Up @@ -359,6 +361,7 @@ async def aadapt(
top_k=kwargs.pop("top_k", 5)
if (user_temperature and not user_top_k)
else kwargs.pop("top_k", -1),
**kwargs
)
if search_data_list is None:
search_data_list = []
Expand Down
4 changes: 3 additions & 1 deletion gptcache/adapter/langchain_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,9 @@ def generate(
callbacks: Callbacks = None,
**kwargs,
) -> LLMResult:
print("kwargs inside generate: ",kwargs)
self.tmp_args = kwargs
return super().generate(messages, stop=stop, callbacks=callbacks)
return super().generate(messages, stop=stop, callbacks=callbacks, **kwargs)

async def agenerate(
self,
Expand All @@ -232,6 +233,7 @@ def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
return self.chat.get_num_tokens_from_messages(messages)

def __call__(self, messages: Any, stop: Optional[List[str]] = None, **kwargs):
print("kwargs in __call__: ", kwargs)
generation = self.generate([messages], stop=stop, **kwargs).generations[0][0]
if isinstance(generation, ChatGeneration):
return generation.message
Expand Down
5 changes: 3 additions & 2 deletions gptcache/adapter/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class ChatCompletion(openai.ChatCompletion, BaseCacheLLM):
@classmethod
def _llm_handler(cls, *llm_args, **llm_kwargs):
try:
_ = llm_kwargs.pop('metadata',{})
return super().create(*llm_args, **llm_kwargs) if cls.llm is None else cls.llm(*llm_args, **llm_kwargs)
except openai.OpenAIError as e:
raise wrap_error(e) from e
Expand All @@ -66,7 +67,7 @@ def _update_cache_callback(
): # pylint: disable=unused-argument
if not isinstance(llm_data, Iterator):
update_cache_func(
Answer(get_message_from_openai_answer(llm_data), DataType.STR)
Answer(get_message_from_openai_answer(llm_data), DataType.STR), **kwargs
)
return llm_data
else:
Expand All @@ -76,7 +77,7 @@ def hook_openai_data(it):
for item in it:
total_answer += get_stream_message_from_openai_answer(item)
yield item
update_cache_func(Answer(total_answer, DataType.STR))
update_cache_func(Answer(total_answer, DataType.STR), **kwargs)

return hook_openai_data(llm_data)

Expand Down
3 changes: 2 additions & 1 deletion gptcache/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def close():
if not os.getenv("IS_CI"):
gptcache_log.error(e)

def import_data(self, questions: List[Any], answers: List[Any], session_ids: Optional[List[Optional[str]]] = None) -> None:
def import_data(self, questions: List[Any], answers: List[Any], session_ids: Optional[List[Optional[str]]] = None, **kwargs) -> None:
"""Import data to GPTCache

:param questions: preprocessed question Data
Expand All @@ -101,6 +101,7 @@ def import_data(self, questions: List[Any], answers: List[Any], session_ids: Opt
answers=answers,
embedding_datas=[self.embedding_func(question) for question in questions],
session_ids=session_ids if session_ids else [None for _ in range(len(questions))],
**kwargs
)

def flush(self):
Expand Down
12 changes: 8 additions & 4 deletions gptcache/manager/data_manager.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from importlib.metadata import metadata
import pickle
from abc import abstractmethod, ABCMeta
from typing import List, Any, Optional, Union
Expand Down Expand Up @@ -35,6 +36,7 @@ def import_data(
answers: List[Any],
embedding_datas: List[Any],
session_ids: List[Optional[str]],
**kwargs
):
pass

Expand Down Expand Up @@ -135,6 +137,7 @@ def import_data(
answers: List[Any],
embedding_datas: List[Any],
session_ids: List[Optional[str]],
**kwargs
):
if (
len(questions) != len(answers)
Expand Down Expand Up @@ -271,7 +274,7 @@ def save(self, question, answer, embedding_data, **kwargs):
"""
session = kwargs.get("session", None)
session_id = session.name if session else None
self.import_data([question], [answer], [embedding_data], [session_id])
self.import_data([question], [answer], [embedding_data], [session_id], **kwargs)

def _process_answer_data(self, answers: Union[Answer, List[Answer]]):
if isinstance(answers, Answer):
Expand Down Expand Up @@ -302,6 +305,7 @@ def import_data(
answers: List[Answer],
embedding_datas: List[Any],
session_ids: List[Optional[str]],
**kwargs,
):
if (
len(questions) != len(answers)
Expand Down Expand Up @@ -332,7 +336,7 @@ def import_data(
[
VectorData(id=ids[i], data=embedding_data)
for i, embedding_data in enumerate(embedding_datas)
]
], kwargs=kwargs
)
self.eviction_base.put(ids)

Expand Down Expand Up @@ -367,8 +371,8 @@ def hit_cache_callback(self, res_data, **kwargs):

def search(self, embedding_data, **kwargs):
embedding_data = normalize(embedding_data)
top_k = kwargs.get("top_k", -1)
return self.v.search(data=embedding_data, top_k=top_k)
top_k = kwargs.pop("top_k", -1)
return self.v.search(data=embedding_data, top_k=top_k, **kwargs)

def flush(self):
self.s.flush()
Expand Down
6 changes: 4 additions & 2 deletions gptcache/manager/vector_data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,19 @@
class VectorData:
id: int
data: np.ndarray
account_id: int = '-1'
pipeline: str = ''


class VectorBase(ABC):
"""VectorBase: base vector store interface"""

@abstractmethod
def mul_add(self, datas: List[VectorData]):
def mul_add(self, datas: List[VectorData], **kwargs):
pass

@abstractmethod
def search(self, data: np.ndarray, top_k: int):
def search(self, data: np.ndarray, top_k: int, **kwargs):
pass

@abstractmethod
Expand Down
4 changes: 2 additions & 2 deletions gptcache/manager/vector_data/chroma.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,11 @@ def __init__(
self._persist_directory = persist_directory
self._collection = self._client.get_or_create_collection(name=collection_name)

def mul_add(self, datas: List[VectorData]):
def mul_add(self, datas: List[VectorData], **kwargs):
data_array, id_array = map(list, zip(*((data.data.tolist(), str(data.id)) for data in datas)))
self._collection.add(embeddings=data_array, ids=id_array)

def search(self, data, top_k: int = -1):
def search(self, data, top_k: int = -1, **kwargs):
if self._collection.count() == 0:
return []
if top_k == -1:
Expand Down
4 changes: 2 additions & 2 deletions gptcache/manager/vector_data/faiss.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,13 @@ def __init__(self, index_file_path, dimension, top_k):
if os.path.isfile(index_file_path):
self._index = faiss.read_index(index_file_path)

def mul_add(self, datas: List[VectorData]):
def mul_add(self, datas: List[VectorData], **kwargs):
data_array, id_array = map(list, zip(*((data.data, data.id) for data in datas)))
np_data = np.array(data_array).astype("float32")
ids = np.array(id_array)
self._index.add_with_ids(np_data, ids)

def search(self, data: np.ndarray, top_k: int = -1):
def search(self, data: np.ndarray, top_k: int = -1, **kwargs):
if self._index.ntotal == 0:
return None
if top_k == -1:
Expand Down
11 changes: 11 additions & 0 deletions gptcache/manager/vector_data/manager.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from gptcache.utils.error import NotFoundError, ParamError
import pinecone

TOP_K = 1

Expand Down Expand Up @@ -257,6 +258,16 @@ def get(name, **kwargs):
flush_interval_sec=flush_interval_sec,
index_params=index_params,
)
elif name == "pinecone":
from gptcache.manager.vector_data.pinecone import Pinecone
api_key = kwargs.get("api_key", None)
metric = kwargs.get("metric",'cosine')
environment = kwargs.get("environment",None)
dimension = kwargs.get("dimension", DIMENSION)
top_k: int = kwargs.get("top_k", TOP_K)
index_name = kwargs.get("index_name", "caching")
pinecone.init(api_key=api_key, environment=environment)
vector_base = Pinecone(index_file_path=index_name,dimension=dimension,top_k=top_k, metric=metric)
else:
raise NotFoundError("vector store", name)
return vector_base
64 changes: 64 additions & 0 deletions gptcache/manager/vector_data/pinecone.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from importlib.metadata import metadata
import os
from typing import List
from xml.etree.ElementInclude import include
import numpy as np
from gptcache.manager.vector_data.base import VectorBase, VectorData
import pinecone
import time

class Pinecone(VectorBase):
"""vector store: Pinecone

:param index_path: the path to Pinecone index, defaults to 'caching'.
:type index_path: str
:param dimension: the dimension of the vector, defaults to 0.
:type dimension: int
:param top_k: the number of the vectors results to return, defaults to 1.
:type top_k: int
"""

def __init__(self, index_file_path, dimension, top_k, metric):
self._index_file_path = index_file_path
self._dimension = dimension
assert metric=='euclidean'
self.indexes = pinecone.list_indexes()
if index_file_path not in self.indexes:
pinecone.create_index(index_file_path, dimension=dimension, metric=metric)
time.sleep(50)
self.index = pinecone.Index(index_file_path)
self._top_k = top_k

def mul_add(self, datas: List[VectorData], **kwargs):
metadata = kwargs.get('kwargs').get('kwargs').pop('metadata',{})
assert metadata!={}, "Please provide the metadata for the following request to process!!"
data_array, id_array= map(list, zip(*((data.data, data.id) for data in datas)))
np_data = np.array(data_array).astype("float32")
ids = np.array(id_array)
upsert_data = [(str(i_d), data.reshape(1,-1).tolist(), {"account_id": int(metadata['account_id']), "pipeline": str(metadata['pipeline'])}) for (i_d,data) in zip(ids,np_data)]
self.index.upsert(upsert_data)

def search(self, data: np.ndarray, top_k: int = -1, **kwargs):
if self.index.describe_index_stats()['total_vector_count'] == 0:
return None
if top_k == -1:
top_k = self._top_k
metadata = kwargs.get("metadata",{})
assert metadata!={}, "Please provide metadata for the search query!!"
np_data = np.array(data).astype("float32").reshape(1, -1)
response = self.index.query(vector = np_data.tolist(), top_k = top_k, include_values = False, filter={"account_id": int(metadata["account_id"]), "pipeline": str(metadata["pipeline"])}) #add additional filter
if len(response['matches'])!=0:
dist, ids = [response['matches'][0]['score']], [int(response['matches'][0]['id'])]
return list(zip(dist, ids))
else:
return None

def rebuild(self, ids=None):
return True

def delete(self, ids):
ids_to_remove = np.array(ids)
self.index.delete(ids=ids_to_remove) # add namespace

def count(self):
return self.index.describe_index_stats()['total_vector_count']
58 changes: 58 additions & 0 deletions testing-pinecone-chat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import time
from gptcache import cache
from gptcache.adapter import openai
from gptcache.embedding import Onnx
from gptcache.manager import CacheBase, VectorBase, get_data_manager
from gptcache.similarity_evaluation.distance import SearchDistanceEvaluation
import pdb
from gptcache.adapter.langchain_models import LangChainChat
from langchain.chat_models import ChatOpenAI
from langchain.schema import (
AIMessage,
HumanMessage,
SystemMessage
)

print("Cache loading.....")

def get_msg(data, **_):
return data.get("messages")[-1].content

onnx = Onnx()
### you can uncomment the following lines according to which database you want to use
# data_manager = get_data_manager(CacheBase("sqlite"), VectorBase("pinecone", \
# dimension=onnx.dimension, index_name='caching', api_key='e0c287dd-b4a3-4600-ad42-5bf792decf19',\
# environment = 'asia-southeast1-gcp-free', metric='euclidean'))

data_manager = get_data_manager(CacheBase("sqlite"), VectorBase("faiss", dimension=onnx.dimension))
# data_manager = get_data_manager(CacheBase("sqlite"), VectorBase("chromadb", dimension=onnx.dimension))
# data_manager = get_data_manager(CacheBase("sqlite"), VectorBase("docarray", dimension=onnx.dimension))

cache.init(
pre_embedding_func=get_msg,
embedding_func=onnx.to_embeddings,
data_manager=data_manager,
similarity_evaluation=SearchDistanceEvaluation(),
)
cache.set_openai_key()

chat = LangChainChat(chat=ChatOpenAI(temperature=0.0))

questions = [
"tell me something about chatgpt",
"what is chatgpt?",
]

metadata = {
'account_id': '-123',
'pipeline': 'completion'
}

if __name__=="__main__":
# pdb.set_trace()
for question in questions:
start_time = time.time()
messages = [HumanMessage(content=question)]
print(chat(messages, metadata=metadata))
print(f'Question: {question}')
print("Time consuming: {:.2f}s".format(time.time() - start_time))
Loading
Loading