diff --git a/bootstraprag/templates/llamaindex/rag_with_self_correction/self_correction_core.py b/bootstraprag/templates/llamaindex/rag_with_self_correction/self_correction_core.py index 68709bf..f42e458 100644 --- a/bootstraprag/templates/llamaindex/rag_with_self_correction/self_correction_core.py +++ b/bootstraprag/templates/llamaindex/rag_with_self_correction/self_correction_core.py @@ -29,7 +29,8 @@ class SelfCorrectingRAG: ] def __init__(self, input_dir: str, similarity_top_k: int = 3, chunk_size: int = 128, - chunk_overlap: int = 100, show_progress: bool = False, no_of_retries: int = 5, required_exts: list[str] = ['.pdf', '.txt']): + chunk_overlap: int = 100, show_progress: bool = False, no_of_retries: int = 5, + required_exts: list[str] = ['.pdf', '.txt']): self.input_dir = input_dir self.similarity_top_k = similarity_top_k @@ -108,7 +109,8 @@ def query_with_retry_query_engine(self, query: str) -> RESPONSE_TYPE: # source nodes for the query based on llm node evaluation. def query_with_source_query_engine(self, query: str) -> RESPONSE_TYPE: retry_source_query_engine = RetrySourceQueryEngine(self.base_query_engine, - self.query_response_evaluator) + self.query_response_evaluator, + max_retries=self.no_of_retries) retry_source_response = retry_source_query_engine.query(query) return retry_source_response @@ -121,6 +123,7 @@ def query_with_guideline_query_engine(self, query: str) -> RESPONSE_TYPE: "The response should try to summarize where possible.\n" ) # just for example retry_guideline_query_engine = RetryGuidelineQueryEngine(self.base_query_engine, - guideline_eval, resynthesize_query=True) + guideline_eval, resynthesize_query=True, + max_retries=self.no_of_retries) retry_guideline_response = retry_guideline_query_engine.query(query) return retry_guideline_response diff --git a/bootstraprag/templates/llamaindex/rag_with_self_correction_with_observability/self_correction_core.py 
b/bootstraprag/templates/llamaindex/rag_with_self_correction_with_observability/self_correction_core.py index 3706c26..d269e94 100644 --- a/bootstraprag/templates/llamaindex/rag_with_self_correction_with_observability/self_correction_core.py +++ b/bootstraprag/templates/llamaindex/rag_with_self_correction_with_observability/self_correction_core.py @@ -7,6 +7,7 @@ from llama_index.llms.ollama import Ollama from llama_index.embeddings.ollama import OllamaEmbedding from llama_index.vector_stores.qdrant import QdrantVectorStore +from llama_index.core.base.base_query_engine import BaseQueryEngine from llama_index.core.base.response.schema import Response, StreamingResponse, AsyncStreamingResponse, PydanticResponse from llama_index.core.query_engine import RetryQueryEngine, RetrySourceQueryEngine, RetryGuidelineQueryEngine from llama_index.core.evaluation import RelevancyEvaluator, GuidelineEvaluator @@ -68,8 +69,8 @@ def __init__(self, input_dir: str, similarity_top_k: int = 3, chunk_size: int = api_key=os.environ['DB_API_KEY']) self.vector_store = QdrantVectorStore(client=self.client, collection_name=os.environ['COLLECTION_NAME']) self.query_response_evaluator = RelevancyEvaluator() - self.base_query_engine = None - self._index = None + self.base_query_engine: BaseQueryEngine = None + self._index: VectorStoreIndex = None self._load_data_and_create_engine()