#### Adaptive Retrieval-Augmented Generation (RAG) System

This system implements an advanced Retrieval-Augmented Generation (RAG) approach that adapts its retrieval strategy to the type of query. It uses Large Language Models (LLMs) at several stages: classifying the query, rewriting it, and re-ranking the retrieved documents. The goal is more accurate, relevant, and context-aware responses to user queries.

In [20]:
import os
import sys
from dotenv import load_dotenv
load_dotenv()
from langchain_groq import ChatGroq
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain.vectorstores import Chroma
from langchain_core.documents import Document
from langchain_openai import ChatOpenAI
from langchain_core.prompts import PromptTemplate, ChatPromptTemplate
from utility import encode_pdf, show_context, retrieve_context_per_question
from langchain_core.output_parsers import StrOutputParser
from typing import List, Any, Dict
from concurrent.futures import ThreadPoolExecutor, as_completed
from langchain_community.docstore.in_memory import InMemoryDocstore
from tqdm import tqdm
from langchain.vectorstores import Chroma, FAISS
import faiss
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader, TextLoader
from utility import replace_t_with_space
from langchain_experimental.text_splitter import SemanticChunker
import pymupdf
from pydantic import BaseModel, Field

#### Define the query classifier

In [21]:
# Schema for the classifier LLM's structured output.
class Categories(BaseModel):
    # Free-form label; AdaptiveRetriever maps it onto a strategy key.
    # `examples` must be a list of example values (one example: "Factual").
    category:str = Field(
        description="The category of the query. The options are Factual, Analytical, Opinion or Contextual",
        examples=["Factual"])

class QueryClassifier:
    """Label a user query as Factual, Analytical, Opinion or Contextual via a Groq-hosted LLM."""

    def __init__(self):
        """Create the Groq chat model and the prompt -> structured-output chain."""
        self.llm = ChatGroq(
            groq_api_key=os.getenv("GROQ_API_KEY"),
            model_name="llama-3.1-8b-instant",
        )
        self.prompt = PromptTemplate(
            input_variables=["query"],
            template="""Classify the following query into one of these categories: 
            Factual, Analytical, Opinion, or Contextual.
            Query: {query}
            Category:""",
        )
        # Force the reply into the Categories schema so `.category` can be read directly.
        structured_llm = self.llm.with_structured_output(Categories)
        self.classifier_chain = self.prompt | structured_llm

    def classify(self, query):
        """Return the category label (a plain string) the LLM assigns to *query*."""
        print("Classifying query")
        result = self.classifier_chain.invoke({"query": query})
        return result.category


##### Basic retrieval strategy

In [22]:
class BaseRetrievalStrategy:
    """Common setup shared by every retrieval strategy.

    Splits the input texts into overlapping chunks, embeds them into an in-memory
    FAISS index (``self.vectorstore``) and creates the Groq chat model (``self.llm``)
    that subclasses use for query rewriting and re-ranking.
    """

    def __init__(self,texts):
        """Chunk and index *texts*, then create the LLM.

        Args:
            texts: list of raw strings to chunk and index.

        Raises:
            ValueError: if *texts* produces no chunks; ``FAISS.from_documents``
                cannot build an index from an empty document list.
        """
        self.embeddings=HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
        # ~1000-character chunks; the 100-character overlap reduces the chance that
        # content split at a chunk boundary is missed by retrieval.
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=100,
            length_function=len
        )
        self.docs = text_splitter.create_documents(texts)
        if not self.docs:
            raise ValueError("BaseRetrievalStrategy needs at least one non-empty text to index")
        self.vectorstore = FAISS.from_documents(self.docs,self.embeddings)
        groq_api_key=os.getenv("GROQ_API_KEY")
        self.llm=ChatGroq(groq_api_key=groq_api_key,model_name="llama-3.1-8b-instant")

    def retriever(self,query,k=3):
        """Return the ``k`` chunks most similar to *query* (plain vector search)."""
        return self.vectorstore.similarity_search(query,k=k)

    def retrieve(self, query, k=3):
        """Default strategy: plain similarity search; subclasses override this.

        Every strategy must expose ``retrieve`` because that is the method
        ``AdaptiveRetriever`` calls. ``retriever`` is kept for existing callers.
        """
        return self.retriever(query, k)

### Define the factual retrieval strategy

In [23]:
# Schema for the LLM's structured relevance score (prompts ask for a 1-10 scale).
class Relevance(BaseModel):
    # `examples` must be a list of example values (one example: 9.0).
    relevance_score:float = Field(
        description="The relevance score of the document to a query",
        examples=[9.0]
    )

class FactualRetrievalStrategy(BaseRetrievalStrategy):
    """Retrieval for fact-seeking queries: LLM-enhance the query, over-fetch, then LLM-rerank."""

    def retrieve(self,query,k=3):
        """Return the ``k`` chunks the LLM scores as most relevant to an enhanced *query*.

        Args:
            query: the user's question.
            k: number of documents to return; ``2 * k`` candidates are fetched for reranking.

        Returns:
            Up to ``k`` ``Document`` objects, highest relevance score first.
        """
        print("retriving factual")

        # Use the LLM to enhance the query. StrOutputParser turns the chat message into
        # plain text: similarity_search embeds its query, so it needs a str, not a message.
        prompt = PromptTemplate(
            template = "Enhance this factual query for better information retrieval. {query}",
            input_variables=["query"]
        )
        enhance_chain = prompt | self.llm | StrOutputParser()
        enhance_query = enhance_chain.invoke({"query":query})
        print("Enhanced Query : ",enhance_query)

        # Over-fetch 2k candidates so the reranker has something to choose from.
        docs = self.vectorstore.similarity_search(enhance_query,k=k*2)
        if not docs:
            return []

        # Use the LLM to score each candidate's relevance to the enhanced query.
        relevance_prompt = PromptTemplate(
            template = """On a scale of 1-10, how relevant is this document to the query:
            {query} ?
            document : {document}
            Relevance score : """,input_variables=["query","document"]
        )
        relevance_chain = relevance_prompt | self.llm.with_structured_output(Relevance)
        print("Generating relevance score...")

        ranked_docs = []
        for doc in docs:
            result = relevance_chain.invoke({"query":enhance_query,"document":doc.page_content})
            ranked_docs.append((doc, float(result.relevance_score)))

        # Highest score first; list.sort is stable, so ties keep vector-store order.
        ranked_docs.sort(key=lambda pair: pair[1], reverse=True)
        return [doc for doc, _ in ranked_docs[:k]]
    

#### Define the analytical retrieval strategy

In [24]:
# Structured LLM output: indices (into the numbered document list shown in the prompt)
# of the documents the LLM selected.
class SelectedIndices(BaseModel):
    # `examples` is a list of example VALUES; one example value is a whole list of ints.
    indices:List[int] = Field(
        description="Indices of the selected documents",
        examples=[[0,1,2,3]]
    )

# Structured LLM output: the sub-questions an analytical query is decomposed into.
class SubQuery(BaseModel):
    # `examples` is a list of example VALUES; one example value is a whole list of strings.
    sub_queries:List[str] = Field(
        description="List of sub-queries for comprehensive analysis",
        examples=[["What is the population of New York?", "What is the GDP of New York?"]]
    )

class AnalyticalRetrievalStrategy(BaseRetrievalStrategy):
    """Retrieval for analytical queries: decompose into sub-queries, then let the LLM pick a diverse subset."""

    def retrieve(self,query,k=3):
        """Return up to ``k`` diverse, relevant chunks covering the sub-questions of *query*.

        Args:
            query: the user's question.
            k: number of sub-queries to request and maximum number of documents returned.

        Returns:
            Up to ``k`` ``Document`` objects in the order the LLM selected them; an empty
            list when the sub-queries retrieve nothing.
        """
        print("Retriving Analytical")

        # Use the LLM to split the query into sub-queries for a more complete analysis.
        sub_query_prompt = PromptTemplate(
            template = """ Generate {k} sub-queries for query {query}""",
            input_variables=["k","query"]
        )
        sub_query_chain = sub_query_prompt | self.llm.with_structured_output(SubQuery)
        sub_queries = sub_query_chain.invoke({"k":k,"query":query}).sub_queries
        print(f"Sub Queries for comprehensive analysis: {sub_queries}")

        # Collect the top-2 hits per sub-query, skipping chunks already collected so the
        # LLM is not asked to choose between identical documents. The loop variable must
        # not be named `query`: that would overwrite the user's query used further down.
        all_docs = []
        seen_contents = set()
        for sub_query in sub_queries:
            for doc in self.vectorstore.similarity_search(sub_query,k=2):
                if doc.page_content not in seen_contents:
                    seen_contents.add(doc.page_content)
                    all_docs.append(doc)
        if not all_docs:
            return []

        # Use the LLM to pick a diverse, relevant subset with respect to the ORIGINAL query.
        diversity_prompt = PromptTemplate(
            template = """ Select the most diverse and relevant set of {k} documents for the 
            query: '{query}'\n
            Documents: {docs}\n
            Return only the indices of selected documents as a list of integers.""",
            input_variables=["k","query","docs"]
        )
        diversity_chain = diversity_prompt | self.llm.with_structured_output(SelectedIndices)

        docs_text = "\n".join([f"{i} : {doc.page_content}" for i,doc in enumerate(all_docs)])
        selected_indices = diversity_chain.invoke({"query":query,"docs":docs_text,"k":k}).indices
        print("Selected diverse and relevant documents")

        # The LLM's indices are untrusted. Drop out-of-range indices, including negative
        # ones that Python would silently wrap. Also drop repeats and cap the result at k.
        picked = []
        for i in selected_indices:
            if 0 <= i < len(all_docs) and i not in picked:
                picked.append(i)
        return [all_docs[i] for i in picked[:k]]

#### Define Opinion Retriever Strategy

In [25]:
class OpinionRetrievalStrategy(BaseRetrievalStrategy):
    """Retrieval for opinion queries: search per LLM-proposed viewpoint, then let the LLM pick diverse opinions."""

    def retrieve(self, query, k=3):
        """Return up to ``k`` documents representing distinct viewpoints on *query*.

        Args:
            query: the user's question.
            k: number of viewpoints to request and maximum number of documents returned.

        Returns:
            Up to ``k`` ``Document`` objects in the order the LLM selected them; an empty
            list when nothing is retrieved.
        """
        print("retrieving opinion")
        # Use the LLM to identify potential viewpoints (one per line of its reply).
        viewpoints_prompt = PromptTemplate(
            input_variables=["query", "k"],
            template="Identify {k} distinct viewpoints or perspectives on the topic: {query}"
        )
        viewpoints_chain = viewpoints_prompt | self.llm
        input_data = {"query": query, "k": k}
        raw_viewpoints = viewpoints_chain.invoke(input_data).content.split('\n')
        # Blank lines in the reply are not viewpoints; searching for them just repeats the query.
        viewpoints = [line.strip() for line in raw_viewpoints if line.strip()]
        print(f'viewpoints: {viewpoints}')

        # Search the FAISS index built by BaseRetrievalStrategy (self.vectorstore).
        # Fall back to the bare query if the LLM produced no usable viewpoints.
        search_queries = [f"{query} {viewpoint}" for viewpoint in viewpoints] or [query]
        all_docs = []
        seen_contents = set()
        for search_query in search_queries:
            for doc in self.vectorstore.similarity_search(search_query, k=2):
                if doc.page_content not in seen_contents:
                    seen_contents.add(doc.page_content)
                    all_docs.append(doc)
        if not all_docs:
            return []

        # Use the LLM to classify and select diverse opinions (it sees 100-char previews).
        opinion_prompt = PromptTemplate(
            input_variables=["query", "docs", "k"],
            template="Classify these documents into distinct opinions on '{query}' and select the {k} most representative and diverse viewpoints:\nDocuments: {docs}\nSelected indices:"
        )
        opinion_chain = opinion_prompt | self.llm.with_structured_output(SelectedIndices)

        docs_text = "\n".join([f"{i}: {doc.page_content[:100]}..." for i, doc in enumerate(all_docs)])
        input_data = {"query": query, "docs": docs_text, "k": k}
        # SelectedIndices.indices is already a List[int]; it must not be split like a string.
        selected_indices = opinion_chain.invoke(input_data).indices
        print(f'selected diverse and relevant documents')

        # The LLM's indices are untrusted: keep only in-range, non-repeated ones, at most k.
        picked = []
        for i in selected_indices:
            if 0 <= i < len(all_docs) and i not in picked:
                picked.append(i)
        return [all_docs[i] for i in picked[:k]]

#### Define Contextual Retriever Strategy

In [26]:
class ContextualRetrievalStrategy(BaseRetrievalStrategy):
    """Retrieval that folds optional user context into the query, then LLM-reranks the hits."""

    def retrieve(self, query, k=4, user_context=None):
        """Return the ``k`` chunks most relevant to *query* given *user_context*.

        Args:
            query: the user's question.
            k: number of documents to return; ``2 * k`` candidates are fetched for reranking.
            user_context: optional free-text description of the user's situation.

        Returns:
            Up to ``k`` ``Document`` objects, highest relevance score first.
        """
        print("retrieving contextual")
        # The same fallback text feeds both the reformulation and the ranking prompts.
        context = user_context or "No specific context provided"

        # Use LLM to incorporate user context into the query
        context_prompt = PromptTemplate(
            input_variables=["query", "context"],
            template="Given the user context: {context}\nReformulate the query to best address the user's needs: {query}"
        )
        context_chain = context_prompt | self.llm
        contextualized_query = context_chain.invoke({"query": query, "context": context}).content
        print(f'contextualized query: {contextualized_query}')

        # Over-fetch 2k candidates from the FAISS index built by BaseRetrievalStrategy
        # (self.vectorstore; the base class defines no `db` attribute).
        docs = self.vectorstore.similarity_search(contextualized_query, k=k*2)
        if not docs:
            return []

        # Use LLM to rank the relevance of retrieved documents considering the user context
        ranking_prompt = PromptTemplate(
            input_variables=["query", "context", "doc"],
            template="Given the query: '{query}' and user context: '{context}', rate the relevance of this document on a scale of 1-10:\nDocument: {doc}\nRelevance score:"
        )
        ranking_chain = ranking_prompt | self.llm.with_structured_output(Relevance)
        print("ranking docs")

        ranked_docs = []
        for doc in docs:
            input_data = {"query": contextualized_query, "context": context, "doc": doc.page_content}
            score = float(ranking_chain.invoke(input_data).relevance_score)
            ranked_docs.append((doc, score))

        # Sort by relevance score (stable, so ties keep vector-store order) and return top k
        ranked_docs.sort(key=lambda x: x[1], reverse=True)
        return [doc for doc, _ in ranked_docs[:k]]

#### Define the adaptive retriever

In [27]:
class AdaptiveRetriever:
    """Route each query to the retrieval strategy chosen by the LLM query classifier."""

    # Strategy used when the classifier's label matches none of the known categories.
    DEFAULT_CATEGORY = "Factual"

    def __init__(self,texts:List[str]):
        """Create the classifier and one strategy instance per category.

        NOTE: each strategy builds its own embeddings model and FAISS index over *texts*
        (see BaseRetrievalStrategy.__init__).
        """
        self.classifier = QueryClassifier()
        self.strategies = {
            "Factual": FactualRetrievalStrategy(texts),
            "Analytical": AnalyticalRetrievalStrategy(texts),
            "Opinion": OpinionRetrievalStrategy(texts),
            "Contextual": ContextualRetrievalStrategy(texts)
        }

    def get_relevant_documents(self,query:str) -> List[Document]:
        """Classify *query* and return the documents selected by the matching strategy.

        The classifier's label is free text, so it is matched case-insensitively after
        stripping whitespace and surrounding punctuation. Unrecognized labels fall back
        to ``DEFAULT_CATEGORY`` instead of raising ``KeyError``.
        """
        category = self.classifier.classify(query)
        return self._strategy_for(category).retrieve(query)

    def _strategy_for(self, category):
        """Map a raw classifier label (e.g. ``"factual."``) onto one of ``self.strategies``."""
        label = str(category or "").strip().strip(".:!\"'").strip().lower()
        for name, strategy in self.strategies.items():
            if name.lower() == label:
                return strategy
        print(f"Unrecognized query category {category!r}; using {self.DEFAULT_CATEGORY} strategy")
        return self.strategies[self.DEFAULT_CATEGORY]

#### Define a custom retriever that inherits from LangChain's `BaseRetriever`

In [28]:
from langchain_core.retrievers import BaseRetriever
class PydanticAdaptiveRetriever(BaseRetriever):
    """LangChain ``BaseRetriever`` adapter so ``AdaptiveRetriever`` works with ``.invoke``/``.ainvoke``."""

    # Excluded from serialization; AdaptiveRetriever is a plain (non-pydantic) class.
    adaptive_retriever: AdaptiveRetriever = Field(exclude=True)

    class Config:
        # Needed because AdaptiveRetriever is not a pydantic type.
        arbitrary_types_allowed = True

    def _get_relevant_documents(self, query: str) -> List[Document]:
        """Delegate to the wrapped adaptive retriever."""
        return self.adaptive_retriever.get_relevant_documents(query)

    async def _aget_relevant_documents(self, query: str) -> List[Document]:
        """Run the synchronous retrieval in a worker thread.

        Calling the public ``get_relevant_documents`` here would start a second, nested
        retriever run and block the event loop. Calling the internal hook avoids both.
        """
        import asyncio

        return await asyncio.to_thread(self._get_relevant_documents, query)

#### Adaptive RAG class

In [29]:
class AdaptiveRAG:
    """End-to-end adaptive RAG: retrieve with the adaptive retriever, then answer with the Groq LLM."""

    def __init__(self, texts: List[str]):
        """Build the adaptive retriever over *texts* and the answer-generation chain."""
        adaptive_retriever = AdaptiveRetriever(texts)
        self.retriever = PydanticAdaptiveRetriever(adaptive_retriever=adaptive_retriever)
        groq_api_key=os.getenv("GROQ_API_KEY")
        self.llm=ChatGroq(groq_api_key=groq_api_key,model_name="llama-3.1-8b-instant")

        # Create a custom prompt
        prompt_template = """Use the following pieces of context to answer the question at the end. 
        If you don't know the answer, just say that you don't know, don't try to make up an answer.

        {context}

        Question: {question}
        Answer:"""
        prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])

        # Create the LLM chain (prompt -> chat model; no output parser, so it yields a message)
        self.llm_chain = prompt | self.llm

    def answer(self, query: str) -> Any:
        """Answer *query* from the documents the adaptive retriever selects for it.

        Returns:
            The chat model's response message, not a plain ``str``. Read ``.content``
            for the answer text.
        """
        docs = self.retriever.invoke(query)
        context = "\n".join(doc.page_content for doc in docs)
        return self.llm_chain.invoke({"context": context, "question": query})

#### Demo of the Adaptive RAG system

In [30]:
# Usage
# Build the adaptive RAG pipeline over a tiny one-sentence corpus.
# Requires GROQ_API_KEY in the environment (read via os.getenv after load_dotenv()).
# NOTE: AdaptiveRetriever builds four strategies, and each one loads its own
# embedding model and FAISS index over `texts`, so construction is slow
# relative to the corpus size.
texts = [
    "The Earth is the third planet from the Sun and the only astronomical object known to harbor life."
    ]
rag_system = AdaptiveRAG(texts)

##### Showcase the four different types of queries

In [None]:
# One question per intended category. The LLM classifier decides at runtime which
# strategy actually handles each question, so the variable names reflect intent only.
# AdaptiveRAG.answer returns the chat model's message object; `.content` is its text.
factual_result = rag_system.answer("What is the distance between the Earth and the Sun?").content
print(f"Answer: {factual_result}")

analytical_result = rag_system.answer("How does the Earth's distance from the Sun affect its climate?").content
print(f"Answer: {analytical_result}")

opinion_result = rag_system.answer("What are the different theories about the origin of life on Earth?").content
print(f"Answer: {opinion_result}")

# NOTE(review): answer() passes no user_context, so the contextual strategy (if chosen)
# uses its "No specific context provided" fallback.
contextual_result = rag_system.answer("How does the Earth's position in the Solar System influence its habitability?").content
print(f"Answer: {contextual_result}")