[Bug]: GPT cache llama index integration #554

Open
MeghaWalia-eco opened this issue Oct 19, 2023 · 6 comments

@MeghaWalia-eco
Current Behavior

I am trying to integrate GPTCache with LlamaIndex, but LLMPredictor does not accept a cache argument. To work around this, I created a CachedLLMPredictor class that extends LLMPredictor:

from typing import Any

from llama_index import BasePromptTemplate
from llama_index.llm_predictor.base import LLMPredictor
from pydantic import BaseModel


class CachedLLMPredictor(LLMPredictor):
    cache: Any  # Define the cache attribute

    class Config(BaseModel.Config):
        extra = 'allow'  # Allow extra attributes

    def __init__(self, cache, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.cache = cache

    def predict(self, prompt: BasePromptTemplate, **prompt_args: Any) -> str:
        cache_key = (prompt, tuple(sorted(prompt_args.items())))
        if cache_key in self.cache:
            return self.cache[cache_key]
        else:
            result = super().predict(prompt, **prompt_args)
            self.cache[cache_key] = result
            return result

However, the lines self.cache[cache_key] = result and return self.cache[cache_key] throw errors, and this does not work.

My actual problem is that I need to add GPTCache to existing LlamaIndex calls. My current implementation is as follows:

        query_engine = self.__llama_idx_svc.get_query_engine(tenant_id,
                                                              tenant_index,
                                                              tenant_config,
                                                              model_name,
                                                              node_postprocessors=node_postprocessors,
                                                              text_qa_template=text_qa_template,
                                                              synthesizer_mode=synthesizer_mode)

        if query_engine is None:
            raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Unable to construct Query Engine")

        response = query_engine.query(query)
def __get_vector_store_qe(self,
                              tenant_index: Index,
                              tenant_config: Config,
                              model_name: Optional[str] = None,
                              node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
                              text_qa_template: Optional[BasePromptTemplate] = None,
                              synthesizer_mode: Optional[str] = None) -> BaseQueryEngine:
        # Load Index
        loaded_index = self.__index_svc.load_index(tenant_index, tenant_config, model_name)

        # Get Context & LLM
        service_ctx = loaded_index.service_context
        callback_manager = service_ctx.callback_manager

        # Get config
        related_series_prompt = tenant_config.config.get("qna_related_series_prompt", QNA_DEFAULT_RELATED_SERIES_PROMPT )

        # Configure prompt template
        text_qa_template = text_qa_template or None

        if str(tenant_index.name) == 'series_metadata': # TODO - Hardcoded, consider to move to Index Config
            text_qa_template = Prompt(
                related_series_prompt[0],
                prompt_type=PromptType.QUESTION_ANSWER
            )

        # postprocessing setup
        node_postprocessors = node_postprocessors or []

        retriever = VectorIndexRetriever(
            index=loaded_index, 
            similarity_top_k=20,
        )

        # Configure response synthesizer
        synthesizer_mode = synthesizer_mode or "compact"

        response_synthesizer = get_response_synthesizer(
            service_context=service_ctx,
            callback_manager=callback_manager,
            text_qa_template=text_qa_template,
            response_mode=synthesizer_mode,
        )

        # Assemble query engine
        return RetrieverQueryEngine(retriever=retriever,
                                    response_synthesizer=response_synthesizer,
                                    callback_manager=callback_manager,
                                    node_postprocessors=node_postprocessors)

def load_index(self,
               tenant_index: Index,
               tenant_config: Config,
               model_name: Optional[str] = None):

    if tenant_index.type == 'sql_store' or tenant_index.type == 'sql_store_with_meta':
        return self.__load_sql_index(tenant_index, tenant_config, model_name)
    else:
        return self.__load_vector_index(tenant_index, tenant_config, model_name)
    
def get_content_func(data, **_):
    return data.get("prompt").split("Question")[-1]   

# TODO - Move to LlamaIndexSvc
def __load_vector_index(self,
                        tenant_index: Index,
                        tenant_config: Config,
                        model_name: Optional[str] = None):
    
    gptcache_obj = GPTCache(self.init_gptcache)
    
    docstore = MongoDocumentStore.from_uri(db_name=os.getenv("INDEX_DB_NAME"),
                                           uri=os.getenv("INDEX_MONGODB_URL"),
                                           namespace=tenant_index.docstore_namespace)

    vector_store = PGVectorStore.from_params(database=os.getenv("INDEX_DB_NAME"),
                                             host=self.__postgres_host,
                                             port=self.__postgres_port,
                                             user=self.__postgres_user,
                                             password=self.__postgres_pass,
                                             table_name=tenant_index.vector_store_table,
                                             embed_dim=tenant_index.embed_dim,
                                             hybrid_search=tenant_index.hybrid_search,
                                             text_search_config=tenant_index.text_search_config)

    storage_ctx = StorageContext.from_defaults(index_store=self.__index_store,
                                               docstore=docstore,
                                               vector_store=vector_store)

    # get config
    llm_temperature = tenant_config.config.get("llm_temperature", 0)
    llm_num_outputs = tenant_config.config.get("llm_num_outputs", None)
    llm_api_key = tenant_config.config.get("llm_api_key", "")
    llm_model = tenant_config.config.get("llm_model", "")
    if model_name and model_name in MODEL_NAME_MAP:
        llm_model = MODEL_NAME_MAP[model_name] 

    node_parser = SimpleNodeParser(
        text_splitter=TokenTextSplitter(
            chunk_size=tenant_index.chunk_size,
            chunk_overlap=tenant_index.max_chunk_overlap,
            callback_manager=self.__get_llm_callback_manager(),
        ),
    )

    embed_model = OpenAIEmbedding(api_key=llm_api_key)
    llm = OpenAI(temperature=llm_temperature,
                 max_tokens=llm_num_outputs,
                 model=llm_model,
                 api_key=llm_api_key)
    
    service_ctx = ServiceContext.from_defaults(llm_predictor=CachedLLMPredictor(llm=llm, cache=gptcache_obj),
                                               embed_model=embed_model,
                                               node_parser=node_parser,
                                               callback_manager=self.__get_llm_callback_manager())

    return load_index_from_storage(index_id=str(tenant_index.id),
                                   storage_context=storage_ctx,
                                   service_context=service_ctx)
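
For context, init_gptcache (passed to GPTCache(self.init_gptcache) above) is not shown here; assuming the GPTCache class in use is LangChain's langchain.cache.GPTCache wrapper, a typical init function following the documented pattern, reusing get_content_func as the pre-embedding function, looks roughly like this:

import hashlib

from gptcache import Cache
from gptcache.manager.factory import manager_factory


def init_gptcache(cache_obj: Cache, llm_string: str):
    # One cache directory per LLM string, keyed by a hash of llm_string
    hashed_llm = hashlib.sha256(llm_string.encode()).hexdigest()
    cache_obj.init(
        # Use the question part of the prompt as the cache key
        pre_embedding_func=get_content_func,
        # Simple exact-match map store; swap in a vector store for semantic caching
        data_manager=manager_factory(manager="map", data_dir=f"gptcache_{hashed_llm}"),
    )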

Expected Behavior

Need to implement GPT caching in the LLM calls.

Steps To Reproduce

Run the code above.

Environment

No response

Anything else?

No response

@SimFG
Collaborator

SimFG commented Oct 19, 2023

You only need to deal with pydantic's attribute checks in the class, and then you can use GPTCache naturally. Alternatively, you can build an OpenAI proxy service and use GPTCache inside that service.
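
A minimal sketch of the first approach, assuming LLMPredictor is still a pydantic model (as in the LlamaIndex releases current at the time) and using a pydantic private attribute so the cache stays out of field validation (depending on the installed pydantic version, PrivateAttr may need to come from LlamaIndex's pydantic bridge instead):

from typing import Any

from llama_index.llm_predictor.base import LLMPredictor
from pydantic import PrivateAttr


class CachedLLMPredictor(LLMPredictor):
    # Private attributes are ignored by pydantic's field validation
    _cache: Any = PrivateAttr(default=None)

    def __init__(self, *args, cache: Any = None, **kwargs):
        super().__init__(*args, **kwargs)
        self._cache = cache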

@MeghaWalia-eco
Author

I am not getting any pydantic error, but when I try to set or retrieve the cache key I get errors; the lines self.cache[cache_key] = result and return self.cache[cache_key] are the ones that fail.

Can I get an example, based on the code above, of how to do that?

@SimFG
Collaborator

SimFG commented Oct 19, 2023

What is the error?

@MeghaWalia-eco
Author

File "C:\AILatestClone\EconomistDigitalSolutions\openai-hack\app\service\CachedLLMPredictor.py", line 22, in predict
if cache_key in self.cache:
^^^^^^^^^^^^^^^^^^^^^^^
TypeError: argument of type 'GPTCache' is not iterable

File "C:\AILatestClone\EconomistDigitalSolutions\openai-hack\app\service\CachedLLMPredictor.py", line 26, in predict
self.cache[cache_key] = result
~~~~~~~~~~^^^^^^^^^^^
TypeError: 'GPTCache' object does not support item assignment

I think I am not accessing the cache correctly.
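
Looking at LangChain's cache interface, langchain.cache.GPTCache implements lookup() and update() rather than dict-style access, so presumably the override has to go through those methods. A rough, untested sketch of what that might look like (the llm_string namespace below is a placeholder):

from typing import Any

from langchain.schema import Generation
from llama_index import BasePromptTemplate
from llama_index.llm_predictor.base import LLMPredictor


class CachedLLMPredictor(LLMPredictor):
    cache: Any = None  # langchain.cache.GPTCache instance

    def __init__(self, cache: Any, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.cache = cache

    def predict(self, prompt: BasePromptTemplate, **prompt_args: Any) -> str:
        # GPTCache keys on the formatted prompt string, not on a (template, args) tuple
        prompt_str = prompt.format(**prompt_args)
        llm_string = "default"  # placeholder namespace; the wrapper keeps one gptcache instance per llm_string
        cached = self.cache.lookup(prompt_str, llm_string)
        if cached:
            # On a hit, lookup() returns a list of langchain Generation objects
            return cached[0].text

        result = super().predict(prompt, **prompt_args)
        self.cache.update(prompt_str, llm_string, [Generation(text=result)])
        return result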

@SachinGanesh

@MeghaWalia-eco

Were you able to solve this issue?

@MeghaWalia-eco
Author

@SachinGanesh No
