## MultiModal RAG Agent with Tavily

#### RAG Data Ingestion

In [48]:
from dotenv import load_dotenv
load_dotenv()
from unstructured.partition.pdf import partition_pdf
import os
from langchain_openai import ChatOpenAI
from langchain.prompts import ChatPromptTemplate, PromptTemplate
from langchain_core.output_parsers import StrOutputParser
import uuid
from langchain_astradb import AstraDBVectorStore
from langchain.storage import InMemoryStore
from langchain_core.documents import Document
from langchain.embeddings import OpenAIEmbeddings
from langchain.retrievers.multi_vector import MultiVectorRetriever
from pydantic import BaseModel, Field
from typing import Annotated, TypedDict, Literal
from langchain_core.messages import BaseMessage
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import  AnyMessage
import operator
from langgraph.prebuilt import ToolNode, tools_condition
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain.output_parsers import PydanticOutputParser
import warnings
warnings.filterwarnings('ignore')

In [4]:
## Document Loading and Chunking

## Path of the PDF that is handed to partition_and_chunk() below
file_path = 'content//attention.pdf'

def partition_and_chunk(
    file_path,
    *,
    max_characters=2000,
    combine_text_under_n_chars=500,
    new_after_n_chars=6000,
):
    """Partition a PDF with ``partition_pdf`` and chunk the result by title.

    The chunk-size settings used to be hard-coded; they are now keyword-only
    parameters whose defaults are the previous values, so existing calls
    behave exactly as before.

    Args:
        file_path: Path of the PDF file to partition.
        max_characters: Forwarded to ``partition_pdf`` as ``max_characters``.
        combine_text_under_n_chars: Forwarded as ``combine_text_under_n_chars``.
        new_after_n_chars: Forwarded as ``new_after_n_chars``.

    Returns:
        The value returned by ``partition_pdf`` for these settings.
    """
    # NOTE(review): the default new_after_n_chars (6000) is larger than
    # max_characters (2000); per the unstructured chunking docs the soft limit
    # is presumably capped by the hard limit, so 6000 likely has no effect.
    # Confirm and lower it if a real soft limit is wanted.
    return partition_pdf(
        filename=file_path,
        strategy='hi_res',
        infer_table_structure=True,
        # Image blocks are returned inline as base64 payloads on the element
        # metadata instead of being written to disk.
        extract_image_block_types=['Image'],
        extract_image_block_to_payload=True,
        chunking_strategy='by_title',
        max_characters=max_characters,
        combine_text_under_n_chars=combine_text_under_n_chars,
        new_after_n_chars=new_after_n_chars,
    )

In [5]:
## Partition and chunk the PDF at file_path
chunks = partition_and_chunk(file_path)

The `max_size` parameter is deprecated and will be removed in v4.26. Please specify in `size['longest_edge'] instead`.


In [6]:
## Text chunks: every chunk whose type string mentions 'CompositeElement'
texts = list(filter(lambda c: 'CompositeElement' in str(type(c)), chunks))
## Image payloads: base64 strings of the Image elements found among each chunk's original elements
images = [
    element.metadata.image_base64
    for element in (el for c in chunks for el in c.metadata.orig_elements)
    if 'Image' in str(type(element))
]

In [7]:
## Prompt template for summarising one text/table chunk; `{element}` is the template variable
prompt_text = '''
You are an assistant tasked with summarizing tables and text.
Give a concise summary of the table or text.

Respond only with the summary, no additional comment.
Do not start your message by saying "Here is a summary" or anything similar.
just give the summary as it is.

Table or text chunk: {element}
'''

## Instruction sent alongside every image
prompt_image = """Describe the image in detail. Be specific about the architecture, graphs, plots such as bar plot"""
## Single user message made of two parts: the text instruction above and the image
## itself as a base64 JPEG data URL; `{image}` is the template variable
image_message = [
    (
        "user",
        [
         {'type':'text', 'text':prompt_image},
         {'type':'image_url','image_url':{'url':'data:image/jpeg;base64,{image}'},}, 
        ]
    )
]

## Chat model shared by the summarisation chains and the supervisor chain
llm = ChatOpenAI(model = 'gpt-4o-mini')

In [8]:
## get summary of text and images
def summary_for_vs(chunks, texts=None, images=None):
    """Summarise text chunks and images with the shared LLM.

    Args:
        chunks: Chunked document elements. Only used to derive ``texts`` and
            ``images`` when those are not supplied.
        texts: Text chunks to summarise. When ``None``, every chunk whose type
            string mentions ``CompositeElement`` is used.
        images: Base64 image strings to describe. When ``None``, they are
            collected from the ``Image`` elements in each chunk's
            ``metadata.orig_elements``.

    Returns:
        A ``(text_summary, image_summary)`` tuple holding the batch results of
        the text chain and the image chain, in input order.
    """
    # Resolve the defaults at call time. The previous signature bound the
    # module-level `texts` / `images` lists when the function was defined, so
    # re-extracting them later had no effect and `chunks` was never used.
    if texts is None:
        texts = [chunk for chunk in chunks if 'CompositeElement' in str(type(chunk))]
    if images is None:
        images = [
            el.metadata.image_base64
            for chunk in chunks
            for el in chunk.metadata.orig_elements
            if 'Image' in str(type(el))
        ]

    text_prompt_template = ChatPromptTemplate.from_template(prompt_text)
    text_chain = text_prompt_template | llm | StrOutputParser()

    image_prompt_template = ChatPromptTemplate.from_messages(image_message)
    image_chain = image_prompt_template | llm | StrOutputParser()

    text_summary = text_chain.batch(texts)
    image_summary = image_chain.batch(images)

    return text_summary, image_summary
    

In [9]:
## Summaries of the text chunks and of the images, produced by the LLM chains in summary_for_vs
text_summary, image_summary = summary_for_vs(chunks)

In [12]:
## Astra DB connection settings are read from environment variables
## (load_dotenv() is called at the top of the file); os.getenv returns None when a variable is unset
token = os.getenv('ASTRA_DB_APPLICATION_TOKEN')
namespace = os.getenv('ASTRA_DB_KEYSPACE')
endpoint = os.getenv('ASTRA_DB_API_ENDPOINT')
## Embedding model passed to the vector store below
embedding = OpenAIEmbeddings(model="text-embedding-3-small")

In [13]:
## The vector store to index the summary chunks
vector_store = AstraDBVectorStore(
    embedding=embedding,
    collection_name="RAG_Graph",
    api_endpoint=endpoint,
    token=token,
    namespace=namespace,
)

## The storage layer for the parent documents
## (in-memory, so its contents do not outlive this Python session)
store = InMemoryStore()
## Metadata key under which each summary records the id of its parent in the docstore
id_key = "doc_id"

## The retriever (empty to start)
retriever = MultiVectorRetriever(
    vectorstore=vector_store,  ## The vector store to index the summary chunks
    docstore=store,            ## The storage layer for the parent documents
    id_key=id_key,
)


In [14]:
## Adding summaries (text or image) to the vector store and their source chunks to the docstore
def loading_summaries_to_vector_store(retriever, chunks, chunk_summary):
    """Index summaries in the vector store and their source chunks in the docstore.

    A fresh UUID is generated for every chunk. The summary at the same
    position is wrapped in a ``Document`` that carries the UUID under
    ``retriever.id_key``, and the chunk itself is stored in the docstore
    under the same UUID.

    Args:
        retriever: Object exposing ``vectorstore``, ``docstore`` and ``id_key``.
        chunks: Parent items (text chunks or base64 image strings).
        chunk_summary: One summary per chunk, in the same order as ``chunks``.

    Raises:
        ValueError: If ``chunks`` and ``chunk_summary`` differ in length.
    """
    chunks = list(chunks)
    chunk_summary = list(chunk_summary)

    # A length mismatch would pair summaries with the wrong parents (or leave
    # parents without a summary), so fail before anything is written.
    if len(chunks) != len(chunk_summary):
        raise ValueError(
            f"chunks and chunk_summary must have the same length "
            f"(got {len(chunks)} and {len(chunk_summary)})"
        )
    # Nothing to index (e.g. a document without images): skip the store calls.
    if not chunks:
        return

    # Use the key the retriever itself is configured with rather than a global.
    key = retriever.id_key

    ## generate unique id for each chunk
    doc_ids = [str(uuid.uuid4()) for _ in chunks]
    ## Creating Langchain Document objects for each summary
    summary_docs = [
        Document(page_content=summary, metadata={key: doc_id})
        for doc_id, summary in zip(doc_ids, chunk_summary)
    ]

    ## indexing the summaries in vector store and the parents in document store
    retriever.vectorstore.add_documents(summary_docs)
    retriever.docstore.mset(list(zip(doc_ids, chunks)))

In [20]:
## adding text summaries to vector store and the text chunks to the document store
loading_summaries_to_vector_store(retriever,texts,text_summary)
## adding image summaries to vector store and the base64 image strings to the document store
loading_summaries_to_vector_store(retriever,images,image_summary)
## Now the retriever is ready to use

## Agent

In [34]:
## Creating Pydantic class

# Structured output expected from the routing LLM call in supervisor().
# NOTE(review): documented with comments rather than a class docstring on
# purpose - a docstring may end up in the schema text produced by
# parser.get_format_instructions(); confirm before adding one.
class RouterCall_Output(BaseModel):
    # Classification label; the Literal restricts it to these two values.
    Topic: Literal['Related','Not Related'] = Field(description="Classification of the user query")
    # Free-text justification for the chosen Topic.
    Reasoning: str = Field(description='Reasoning behind topic selection')
    
# Parser used to build the format instructions and to turn the LLM reply into a RouterCall_Output
parser = PydanticOutputParser(pydantic_object=RouterCall_Output)

In [49]:
## State class
# Graph state: a single `messages` list annotated with operator.add
# (presumably used by LangGraph as the reducer that appends node updates - confirm).
# NOTE(review): the nodes in this file put plain strings into `messages`,
# not only AnyMessage objects.
class State(TypedDict):
    messages : Annotated[list[AnyMessage], operator.add]

In [None]:
def supervisor(state:State):
    """Classify the latest user query as related or not to Transformer architecture.

    Args:
        state: Graph state. ``state['messages']`` is normally a list whose
            last entry is the query to classify; a bare string is also
            accepted for ad-hoc calls.

    Returns:
        ``{'messages': [topic]}`` where ``topic`` is the ``Topic`` field of the
        parsed LLM response ("Related" or "Not Related").

    Raises:
        ValueError: If ``state['messages']`` is empty.
    """
    messages = state['messages']
    if not messages:
        raise ValueError("state['messages'] is empty; there is no query to classify")

    # Classify only the most recent message. Previously the whole list was
    # formatted into the prompt, so the model saw a Python list repr that
    # could also contain earlier routing labels.
    latest = messages if isinstance(messages, str) else messages[-1]
    # Message objects carry their text in `.content`; plain strings are used as-is.
    question = getattr(latest, 'content', latest)
    print(question)
    
    template = '''
    You are an expert in routing. Your task is to classify if the given user query is related to LLM Transformer architecture.
    You must classify as either "Related" or "Not Related".
    Provide a brief reasoning for your classification.
    
    user query : {question}
    {format_instructions}
    '''
    
    prompt = PromptTemplate(
        template=template,
        input_variables=['question'],
        partial_variables={'format_instructions':parser.get_format_instructions()}
    )
    
    chain = prompt | llm | parser
    response = chain.invoke({'question':question})
    print("Parsed response:", response)
    
    # Only the label is passed on; the reasoning is printed above but not stored.
    return {'messages':[response.Topic]}

In [54]:
## Router function
def router(state: State):
    """Pick the next node from the classification stored in the last message.

    Args:
        state: Graph state; the last entry of ``state['messages']`` is expected
            to be the label written by the supervisor ("Related" or
            "Not Related"), either as a plain string or as an object with a
            ``content`` attribute.

    Returns:
        ``'RAG Call'`` when the label is "Related" (ignoring case and
        surrounding whitespace), otherwise ``'LLM Call'``.
    """
    print('-->Router-->')

    last_message = state['messages'][-1]
    print("last_message: ",last_message)

    # Message objects carry their text in `.content`; plain strings are used as-is.
    label = getattr(last_message, 'content', last_message)

    # Exact comparison on the normalized label. The previous check looked for
    # the capitalized 'Related' inside a lower-cased string, which can never
    # match, and a substring test would also accept "Not Related".
    if str(label).strip().lower() == 'related':
        return 'RAG Call'

    return 'LLM Call'

In [None]:
## RAG Function


In [53]:
## Manual check: call the supervisor node directly (here `messages` is a bare string, not a list)
supervisor({'messages':'what are Transformer'})

{'messages': 'what are Transformer'}
Parsed response: Topic='Related' Reasoning="The query 'what are Transformer' is directly asking about the Transformer architecture, which is a foundational concept in LLM (Large Language Model) architectures, particularly in natural language processing."


{'messages': ['Related']}