In [1]:
%pip install -r requirements.txt

Collecting aiohappyeyeballs==2.4.3
  Using cached aiohappyeyeballs-2.4.3-py3-none-any.whl (14 kB)
Collecting aiohttp==3.10.8
  Using cached aiohttp-3.10.8-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)
Collecting aiosignal==1.3.1
  Using cached aiosignal-1.3.1-py3-none-any.whl (7.6 kB)
Collecting annotated-types==0.7.0
  Using cached annotated_types-0.7.0-py3-none-any.whl (13 kB)
Collecting anyio==4.6.0
  Using cached anyio-4.6.0-py3-none-any.whl (89 kB)
Collecting arxiv==2.1.3
  Using cached arxiv-2.1.3-py3-none-any.whl (11 kB)
Collecting async-timeout==4.0.3
  Using cached async_timeout-4.0.3-py3-none-any.whl (5.7 kB)
Collecting attrs==24.2.0
  Using cached attrs-24.2.0-py3-none-any.whl (63 kB)
Collecting cachetools==5.5.0
  Using cached cachetools-5.5.0-py3-none-any.whl (9.5 kB)
Collecting certifi==2024.8.30
  Using cached certifi-2024.8.30-py3-none-any.whl (167 kB)
Collecting charset-normalizer==3.3.2
  Using cached charset_normalizer-3.3.2-cp310-cp310-manylinu

In [2]:
!mkdir data
#
! wget "https://arxiv.org/pdf/1810.04805.pdf" -O ./data/BERT_arxiv.pdf
! wget "https://arxiv.org/pdf/2005.11401" -O ./data/RAG_arxiv.pdf
! wget "https://arxiv.org/pdf/2310.11511" -O ./data/self_rag_arxiv.pdf
! wget "https://arxiv.org/pdf/2401.15884" -O ./data/crag_arxiv.pdf

--2024-10-04 14:41:44--  https://arxiv.org/pdf/1810.04805.pdf
Resolving arxiv.org (arxiv.org)... 151.101.131.42, 151.101.3.42, 151.101.67.42, ...
Connecting to arxiv.org (arxiv.org)|151.101.131.42|:443... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: http://arxiv.org/pdf/1810.04805 [following]
--2024-10-04 14:41:44--  http://arxiv.org/pdf/1810.04805
Connecting to arxiv.org (arxiv.org)|151.101.131.42|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 775166 (757K) [application/pdf]
Saving to: ‘./data/BERT_arxiv.pdf’


2024-10-04 14:41:44 (3.82 MB/s) - ‘./data/BERT_arxiv.pdf’ saved [775166/775166]

--2024-10-04 14:41:44--  https://arxiv.org/pdf/2005.11401
Resolving arxiv.org (arxiv.org)... 151.101.195.42, 151.101.67.42, 151.101.131.42, ...
Connecting to arxiv.org (arxiv.org)|151.101.195.42|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 885323 (865K) [application/pdf]
Saving to: ‘./data/RAG_arxiv.pdf’




Import Required Dependencies

In [38]:
from llama_index.core import SimpleDirectoryReader,VectorStoreIndex,SummaryIndex
from llama_index.vector_stores.chroma import ChromaVectorStore
from llama_index.core import StorageContext
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.tools import FunctionTool,QueryEngineTool
from llama_index.core.vector_stores import MetadataFilters,FilterCondition
from llama_index.core.llms import MockLLM
from typing import List,Optional

In [39]:
import  nest_asyncio
nest_asyncio.apply()

Read the Documents

In [40]:
documents = SimpleDirectoryReader(input_files = ['./data/self_rag_arxiv.pdf']).load_data()
print(len(documents))
print(f"Document Metadata: {documents[0].metadata}")

30
Document Metadata: {'page_label': '1', 'file_name': 'self_rag_arxiv.pdf', 'file_path': 'data/self_rag_arxiv.pdf', 'file_type': 'application/pdf', 'file_size': 1405127, 'creation_date': '2024-10-04', 'last_modified_date': '2023-10-19'}


Split the documents into chunks/nodes

In [41]:
splitter = SentenceSplitter(chunk_size=1024,chunk_overlap=100)
nodes = splitter.get_nodes_from_documents(documents)
print(f"Length of nodes : {len(nodes)}")
print(f"get the content for node 0 :{nodes[0].get_content(metadata_mode='all')}")

Length of nodes : 43
get the content for node 0 :page_label: 1
file_name: self_rag_arxiv.pdf
file_path: data/self_rag_arxiv.pdf
file_type: application/pdf
file_size: 1405127
creation_date: 2024-10-04
last_modified_date: 2023-10-19

Preprint.
SELF-RAG: LEARNING TO RETRIEVE , GENERATE ,AND
CRITIQUE THROUGH SELF-REFLECTION
Akari Asai†, Zeqiu Wu†, Yizhong Wang†§, Avirup Sil‡, Hannaneh Hajishirzi†§
†University of Washington§Allen Institute for AI‡IBM Research AI
{akari,zeqiuwu,yizhongw,hannaneh }@cs.washington.edu ,avi@us.ibm.com
ABSTRACT
Despite their remarkable capabilities, large language models (LLMs) often produce
responses containing factual inaccuracies due to their sole reliance on the paramet-
ric knowledge they encapsulate. Retrieval-Augmented Generation (RAG), an ad
hoc approach that augments LMs with retrieval of relevant knowledge, decreases
such issues. However, indiscriminately retrieving and incorporating a fixed number
of retrieved passages, regardless of whether retrieval 

Instantiate the vectorstore

In [42]:
import chromadb
db = chromadb.PersistentClient(path="./chroma_db_mistral")
chroma_collection = db.get_or_create_collection("multidocument-agent")
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
storage_context = StorageContext.from_defaults(vector_store=vector_store)

Instantiate the embedding model

In [43]:
from llama_index.embeddings.fastembed import FastEmbedEmbedding
from llama_index.core import Settings

embed_model = FastEmbedEmbedding(model_name="BAAI/bge-small-en-v1.5")

Settings.embed_model = embed_model

Settings.chunk_size = 1024


Fetching 5 files: 100%|██████████| 5/5 [00:00<00:00, 16033.27it/s]


In [44]:
!source .env

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [45]:
import os 
from dotenv import load_dotenv
from llama_index.llms.mistralai import MistralAI

load_dotenv()
# mistral_api_key=  os.environ["MISTRAL_API_KEY"] 
llm = MistralAI(model="mistral-large-latest",api_key=os.getenv("MISTRAL_API_KEY"))

In [46]:
#instantiate Vectorstore

name = "BERT_arxiv"
vector_index = VectorStoreIndex(nodes,storage_context=storage_context)
vector_index.storage_context.vector_store.persist(persist_path="/content/chroma_db")
#
# Define Vectorstore Autoretrieval tool
def vector_query(query:str,page_numbers:Optional[List[str]]=None)->str:
  '''
  perform vector search over index on
  query(str): query string needs to be embedded
  page_numbers(List[str]): list of page numbers to be retrieved,
                          leave blank if we want to perform a vector search over all pages
  '''
  page_numbers = page_numbers or []
  metadata_dict = [{"key":'page_label',"value":p} for p in page_numbers]
  #
  query_engine = vector_index.as_query_engine(llm=MockLLM(),similarity_top_k =2,
                                              filters = MetadataFilters.from_dicts(metadata_dict,
                                                                                    condition=FilterCondition.OR)
                                              )
  #
  response = query_engine.query(query)
  return response
#
#llamiondex FunctionTool wraps any python function we feed it
vector_query_tool = FunctionTool.from_defaults(name=f"vector_tool_{name}",
                                              fn=vector_query)
# Prepare Summary Tool
summary_index = SummaryIndex(nodes)
# summary_query_engine = summary_index.as_query_engine(llm=MockLLM(), embed_model="local")
summary_query_engine = summary_index.as_query_engine(llm=MockLLM(), response_mode="tree_summarize",
                                                      se_async=True,)
summary_query_tool = QueryEngineTool.from_defaults(name=f"summary_tool_{name}",
                                                    query_engine=summary_query_engine,
                                                  description=("Use ONLY IF you want to get a holistic summary of the documents."
                                              "DO NOT USE if you have specified questions over the documents."))

Test the LLM

In [51]:
response = llm.predict_and_call([vector_query_tool],
                                "Summarize the content in page number 2",
                                verbose=True)
response

Helper function to generate Vectorstore Tool and Summary tool for all the documents

In [23]:
def get_doc_tools(file_path:str,name:str)->str:
  '''
  get vector query and sumnmary query tools from a document
  '''
  #load documents
  documents = SimpleDirectoryReader(input_files = [file_path]).load_data()
  print(f"length of nodes")
  splitter = SentenceSplitter(chunk_size=1024,chunk_overlap=100)
  nodes = splitter.get_nodes_from_documents(documents)
  print(f"Length of nodes : {len(nodes)}")
  #instantiate Vectorstore
  vector_index = VectorStoreIndex(nodes,storage_context=storage_context)
  vector_index.storage_context.vector_store.persist(persist_path="/content/chroma_db")

  # Define Vectorstore Autoretrieval tool
  def vector_query(query:str,page_numbers:Optional[List[str]]=None)->str:
    '''
    perform vector search over index on
    query(str): query string needs to be embedded
    page_numbers(List[str]): list of page numbers to be retrieved,
                            leave blank if we want to perform a vector search over all pages
    '''
    page_numbers = page_numbers or []
    metadata_dict = [{"key":'page_label',"value":p} for p in page_numbers]
  
    query_engine = vector_index.as_query_engine(llm=MockLLM(),similarity_top_k =2,
                                                filters = MetadataFilters.from_dicts(metadata_dict,
                                                                                     condition=FilterCondition.OR)
                                                )
  
    response = query_engine.query(query)
    return response
  
  #llamiondex FunctionTool wraps any python function we feed it
  vector_query_tool = FunctionTool.from_defaults(name=f"vector_tool_{name}",
                                                fn=vector_query)
  # Prepare Summary Tool
  summary_index = SummaryIndex(nodes)
  summary_query_engine = summary_index.as_query_engine(llm=MockLLM(),response_mode="tree_summarize",
                                                       se_async=True,)
  summary_query_tool = QueryEngineTool.from_defaults(name=f"summary_tool_{name}",
                                                     query_engine=summary_query_engine,
                                                    description=("Use ONLY IF you want to get a holistic summary of the documents."
                                                "DO NOT USE if you have specified questions over the documents."))
  return vector_query_tool,summary_query_tool


Prepare a input list with specified document names

In [24]:
import os
root_path = "data"
file_name = []
file_path = []
for file in os.listdir(root_path):
  if file.endswith(".pdf"):
    file_name.append(file.split(".")[0])
    file_path.append(os.path.join(root_path,file))
#
print(file_name)
print(file_path)

['BERT_arxiv', 'self_rag_arxiv', 'crag_arxiv', 'RAG_arxiv']
['data/BERT_arxiv.pdf', 'data/self_rag_arxiv.pdf', 'data/crag_arxiv.pdf', 'data/RAG_arxiv.pdf']


Generate the vectortool and summary tool for each documents

In [25]:
papers_to_tools_dict = {}
for name,filename in zip(file_name,file_path):
  vector_query_tool,summary_query_tool = get_doc_tools(filename,name)
  papers_to_tools_dict[name] = [vector_query_tool,summary_query_tool]

length of nodes
Length of nodes : 28
length of nodes
Length of nodes : 43
length of nodes
Length of nodes : 22
length of nodes
Length of nodes : 30


Get the tools into a flat list

In [26]:
initial_tools = [t for f in file_name for t in papers_to_tools_dict[f]]
initial_tools

[<llama_index.core.tools.function_tool.FunctionTool at 0x71c84018afb0>,
 <llama_index.core.tools.query_engine.QueryEngineTool at 0x71c84017d6f0>,
 <llama_index.core.tools.function_tool.FunctionTool at 0x71c837f39fc0>,
 <llama_index.core.tools.query_engine.QueryEngineTool at 0x71c837f38850>,
 <llama_index.core.tools.function_tool.FunctionTool at 0x71c840097730>,
 <llama_index.core.tools.query_engine.QueryEngineTool at 0x71c840097670>,
 <llama_index.core.tools.function_tool.FunctionTool at 0x71c837b76950>,
 <llama_index.core.tools.query_engine.QueryEngineTool at 0x71c837b74790>]

In [27]:
from llama_index.core import VectorStoreIndex
from llama_index.core.objects import ObjectIndex

obj_index = ObjectIndex.from_objects(initial_tools,index_cls=VectorStoreIndex)

Set up the ObjectIndex as retriever

In [28]:
obj_retriever = obj_index.as_retriever(similarity_top_k=2)
tools = obj_retriever.retrieve("compare and contrast the papers self rag and corrective rag")
#
print(tools[0].metadata)
print(tools[1].metadata)

ToolMetadata(description='Use ONLY IF you want to get a holistic summary of the documents.DO NOT USE if you have specified questions over the documents.', name='summary_tool_self_rag_arxiv', fn_schema=<class 'llama_index.core.tools.types.DefaultToolFnSchema'>, return_direct=False)
ToolMetadata(description='vector_tool_self_rag_arxiv(query: str, page_numbers: Optional[List[str]] = None) -> str\n\n    perform vector search over index on\n    query(str): query string needs to be embedded\n    page_numbers(List[str]): list of page numbers to be retrieved,\n                            leave blank if we want to perform a vector search over all pages\n    ', name='vector_tool_self_rag_arxiv', fn_schema=<class 'llama_index.core.tools.utils.vector_tool_self_rag_arxiv'>, return_direct=False)


#### Setup the RAG Agent

In [31]:
from llama_index.core.agent import FunctionCallingAgentWorker
from llama_index.core.agent import AgentRunner
#
agent_worker = FunctionCallingAgentWorker.from_tools(tool_retriever=obj_retriever,
                                                     llm=llm,
                                                     system_prompt="""You are an agent designed to answer queries over a set of given papers.
                                                     Please always use the tools provided to answer a question.Do not rely on prior knowledge.""",
                                                     verbose=True)
agent = AgentRunner(agent_worker)

### Ask Query 1

In [None]:
response = agent.query("Compare and contrast self rag and crag.")
print(str(response))

### Ask Query 2



In [None]:
response = agent.query("Summarize the paper corrective RAG.")
print(str(response))