
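```python
# Install the libraries used in this notebook (-q: quiet, -U: upgrade if already installed).
# Gradio is installed from a pinned 3.41.2 wheel.
%pip install -qU langchain_core langchain_openai chromadb sentence-transformers docling langchain-text-splitters "langchain-chroma>=0.1.2" langchain_huggingface https://gradio-builds.s3.amazonaws.com/a0c487cd57a217775f0d1bc77c041b7cd516cc8a/gradio-3.41.2-py3-none-any.whl
```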

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.3/10.3 MB[0m [31m23.0 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m48.2/48.2 kB[0m [31m1.4 MB/s[0m eta [36m0:00:00[0m
[?25h

# Introduction to Retrieval Augmented Generation (RAG)

__Authored by__: Enrique Noriega-Atala
__Last edited__: 10/07/2024

Retrieval-Augmented Generation (RAG) is a method for prompting an LLM to generate responses grounded in information retrieved from a database, hence _retrieval_.

RAG can help reduce LLM hallucinations, and it lets you _interface_ with documents in a more natural way: more like a conversation than a search engine.
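To make the two RAG steps concrete, here is a deliberately tiny, dependency-free sketch: first retrieve the passages most related to the question, then paste them into the prompt that goes to the LLM. It scores passages by word overlap, a crude stand-in (our own assumption, not what this notebook uses) for the embedding search built below; `retrieve` and `build_prompt` are hypothetical names, and the actual LLM call is left out.

```python
import re


def tokenize(text: str) -> set[str]:
    """Lowercase word set; a crude stand-in for a real embedding."""
    return set(re.findall(r"[a-z0-9]+", text.lower()))


def retrieve(question: str, passages: list[str], k: int = 2) -> list[str]:
    """Retrieval step: return the k passages sharing the most words with the question."""
    q = tokenize(question)
    ranked = sorted(passages, key=lambda p: len(q & tokenize(p)), reverse=True)
    return ranked[:k]


def build_prompt(question: str, context: list[str]) -> str:
    """Augmentation step: put the retrieved passages in front of the question."""
    joined = "\n".join(f"- {c}" for c in context)
    return (
        "Answer using only the passages below.\n"
        f"Passages:\n{joined}\n"
        f"Question: {question}\nAnswer:"
    )


passages = [
    "The Transformer relies entirely on attention mechanisms.",
    "Recurrent networks process tokens one at a time.",
    "Multi-head attention runs several attention functions in parallel.",
]
question = "What does the Transformer rely on?"
prompt = build_prompt(question, retrieve(question, passages))
print(prompt)  # this string is what the generation step would send to the LLM
```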

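## Step 1: Document Parsing

We use `docling` to convert the PDF of _Attention Is All You Need_ (arXiv:1706.03762), the paper that introduced the Transformer, into Markdown text.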

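```python
from docling.document_converter import DocumentConverter

source = "https://arxiv.org/pdf/1706.03762"  # PDF path or URL
converter = DocumentConverter()
# Current docling releases (2.x) use convert(); older 1.x releases used convert_single()
result = converter.convert(source)
# Export the parsed document as a single Markdown string
markdown = result.document.export_to_markdown()
```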

```python
from IPython.display import display, Markdown

# Render the parsed Markdown inline in the notebook
display(Markdown(markdown))
```

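## Step 2: Split the document

Splitting happens in two passes: first on the Markdown headers, so every chunk remembers which section it came from, then into short, slightly overlapping character-level chunks, so each retrieved piece of context stays small and focused.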

```python
from langchain_text_splitters import MarkdownHeaderTextSplitter, RecursiveCharacterTextSplitter

# First pass: split on level-1 (#) and level-2 (##) Markdown headers.
# Each section records its headers in metadata under "Header 1" / "Header 2".
headers_to_split_on = [
    ("#", "Header 1"),
    ("##", "Header 2"),
]
markdown_splitter = MarkdownHeaderTextSplitter(
    headers_to_split_on=headers_to_split_on, strip_headers=False  # keep header lines in the text
)
md_header_splits = markdown_splitter.split_text(markdown)

# Second pass: break each section into ~250-character chunks,
# with 30 characters of overlap between consecutive chunks
chunk_size = 250
chunk_overlap = 30
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=chunk_size, chunk_overlap=chunk_overlap
)

# split_documents copies each section's metadata onto all of its chunks
splits = text_splitter.split_documents(md_header_splits)
splits
```
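Conceptually, the first pass works like the simplified sketch below: walk the Markdown line by line, start a new section at each `#` or `##` heading, and record the enclosing headings as metadata. This is our own illustration (`split_on_headers` is a hypothetical name), not LangChain's implementation, which handles more cases such as headings inside code fences.

```python
def split_on_headers(markdown: str) -> list[dict]:
    """Split Markdown into sections at '#' and '##' headings.

    Each section keeps its text and a metadata dict with the enclosing headings,
    mirroring the "Header 1"/"Header 2" keys used in this notebook.
    """
    sections = []
    metadata: dict[str, str] = {}
    lines: list[str] = []

    def flush():
        # Close the current section (if it has any text) with the current headings
        text = "\n".join(lines).strip()
        if text:
            sections.append({"metadata": dict(metadata), "content": text})
        lines.clear()

    for line in markdown.splitlines():
        if line.startswith("# "):
            flush()
            metadata = {"Header 1": line[2:].strip()}  # a new level-1 heading resets everything
        elif line.startswith("## "):
            flush()
            metadata["Header 2"] = line[3:].strip()  # replaces any previous level-2 heading
        lines.append(line)  # keep heading lines in the text, like strip_headers=False
    flush()
    return sections


doc = "# Title\nIntro text.\n## Background\nSome background.\n## Model\nThe model."
for section in split_on_headers(doc):
    print(section["metadata"])
```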

```python
print(f'Total number of "chunks": {len(splits)}')
```
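The second pass caps chunk length. A simplified fixed-window version of the idea is sketched below: each chunk holds at most `chunk_size` characters, and consecutive chunks share `chunk_overlap` characters, so text near a boundary appears in both neighbouring chunks. `window_chunks` is a hypothetical helper; unlike `RecursiveCharacterTextSplitter`, it does not try to break at paragraph, line, or word boundaries, so its chunk counts will not match the real splitter's.

```python
def window_chunks(text: str, chunk_size: int, chunk_overlap: int) -> list[str]:
    """Fixed-size character windows; consecutive windows share chunk_overlap chars."""
    if not 0 <= chunk_overlap < chunk_size:
        raise ValueError("need 0 <= chunk_overlap < chunk_size")
    step = chunk_size - chunk_overlap  # how far each window advances
    chunks = []
    for start in range(0, len(text), step):
        chunks.append(text[start:start + chunk_size])
        if start + chunk_size >= len(text):
            break  # this window already reached the end of the text
    return chunks


print(window_chunks("abcdefghij", chunk_size=4, chunk_overlap=1))  # ['abcd', 'defg', 'ghij']
```

With the notebook's settings (250 and 30), each window advances 220 characters, so a 1,000-character section yields 5 windows in this sketch.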

## Step 3: Encode the splits into a _Vector Database_

To run _semantic_ queries (matching on meaning rather than exact keywords), we need a _vector representation_, or embedding, of each chunk.

We will compute the embeddings with a `sentence-transformers` model and store them in `ChromaDB`, a vector database.
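Vector stores rank chunks by a similarity (or distance) measure between these vectors. Cosine similarity is a common choice: it is close to 1 for vectors pointing in the same direction and near 0 for orthogonal ones. The sketch below computes it in plain Python on made-up 3-dimensional vectors; real sentence-transformers embeddings have hundreds of dimensions, and the numbers and variable names here are purely illustrative.

```python
import math


def cosine_similarity(u: list[float], v: list[float]) -> float:
    """cos(u, v) = (u . v) / (|u| |v|)"""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


# Made-up vectors for illustration only
query = [1.0, 0.0, 1.0]
chunk_about_attention = [0.9, 0.1, 0.8]
chunk_about_training = [0.0, 1.0, 0.1]

print(cosine_similarity(query, chunk_about_attention))  # close to 1: similar direction
print(cosine_similarity(query, chunk_about_training))   # close to 0: mostly unrelated
```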

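```python
from langchain_chroma import Chroma
from langchain_huggingface import HuggingFaceEmbeddings

# Embedding model; with no arguments, the library's default sentence-transformers model is used
embeddings = HuggingFaceEmbeddings()

# Vector store backed by ChromaDB. No persist_directory is given,
# so the index lives in memory and is lost when the runtime restarts.
vector_store = Chroma(
    embedding_function=embeddings,
)

# Embed every chunk and add it to the index
vector_store.add_documents(splits)
```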

```python
# Retrieve the chunks whose embeddings are most similar to the query's (4 by default)
retriever = vector_store.as_retriever()

retriever.invoke("Attention is")
```

```python
# We can configure the search parameters, such as the number of chunks returned (k)
retriever = vector_store.as_retriever(search_kwargs={'k': 10})
retriever.invoke("Attention is")
```
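Conceptually, `k` is just how many of the best-scoring chunks come back. Below is a brute-force sketch of top-k search over an in-memory list of (text, vector) pairs. It is only an illustration: vector databases such as Chroma use approximate nearest-neighbour indexes (HNSW in Chroma's case) rather than scoring every chunk, and the `top_k` helper and toy vectors are our own assumptions.

```python
import heapq
import math


def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def top_k(query_vec, index, k=4):
    """Return the k (text, score) pairs most similar to query_vec, best first."""
    scored = ((text, cosine(query_vec, vec)) for text, vec in index)
    return heapq.nlargest(k, scored, key=lambda pair: pair[1])


# Toy index: (chunk text, made-up embedding)
index = [
    ("attention chunk", [0.9, 0.1, 0.8]),
    ("training chunk", [0.0, 1.0, 0.1]),
    ("results chunk", [0.5, 0.5, 0.5]),
]
for text, score in top_k([1.0, 0.0, 1.0], index, k=2):
    print(f"{score:.3f}  {text}")
```

Asking for a larger `k` than there are chunks simply returns all of them.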

## Step 4: Connect the retrieval to an LLM

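We will use Verde for this purpose. Its endpoint is OpenAI-compatible, so LangChain's `ChatOpenAI` client can talk to it once `base_url` points at the Verde server and an API key is available.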

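```python
import os
from google.colab import userdata  # Colab-only module

# ChatOpenAI reads the key from the OPENAI_API_KEY environment variable;
# here it is loaded from Colab's secrets store
os.environ['OPENAI_API_KEY'] = userdata.get('OPENAI_API_KEY')
```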
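`google.colab.userdata` only exists inside Colab. If you run the notebook elsewhere, one option (our suggestion, not part of the original setup) is to fall back to an already-exported environment variable, or to prompt for the key with `getpass`. `load_api_key` is a hypothetical helper.

```python
import getpass
import os


def load_api_key(var: str = "OPENAI_API_KEY") -> str:
    """Hypothetical helper: get the API key from Colab secrets, the environment, or a prompt."""
    try:
        from google.colab import userdata  # importable only inside Colab
        key = userdata.get(var)
    except ImportError:
        # Outside Colab: reuse an exported variable, otherwise ask interactively
        key = os.environ.get(var) or getpass.getpass(f"{var}: ")
    os.environ[var] = key  # ChatOpenAI reads the key from this variable
    return key
```

Calling `load_api_key()` in place of the `userdata.get(...)` line sets the variable the same way in both environments.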

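```python
from langchain_openai import ChatOpenAI

# Verde's OpenAI-compatible endpoint
API_ENDPOINT = "https://llm1.cyverse.ai/v1"

# Point the OpenAI client at Verde instead of api.openai.com
llm = ChatOpenAI(model="Mistral-7B-Instruct-v0.3", base_url=API_ENDPOINT)

# Quick smoke test: the reply comes back as an AIMessage
llm.invoke("Hello! What can you tell me about yourself?")
```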
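Because the endpoint is OpenAI-compatible, the OpenAI client behind `ChatOpenAI` sends each chat call as an HTTP POST to `{base_url}/chat/completions` with a JSON body listing the messages. The sketch below only builds such a request (it sends nothing) to show roughly what goes over the wire. The exact extra fields LangChain adds (temperature, streaming flags, and so on) may differ, `build_chat_request` is a hypothetical helper, and `YOUR_KEY` is a placeholder.

```python
import json


def build_chat_request(base_url: str, model: str, user_message: str, api_key: str) -> tuple[str, dict, str]:
    """Return (url, headers, json_body) for an OpenAI-style chat completion call."""
    url = base_url.rstrip("/") + "/chat/completions"
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    body = {
        "model": model,
        "messages": [{"role": "user", "content": user_message}],
    }
    return url, headers, json.dumps(body)


url, headers, body = build_chat_request(
    "https://llm1.cyverse.ai/v1",
    "Mistral-7B-Instruct-v0.3",
    "Hello! What can you tell me about yourself?",
    api_key="YOUR_KEY",  # placeholder, not a real key
)
print(url)
print(body)
```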

Once we have an LLM client, we can wire together the retrieval component and the LLM. First, let's chain the retriever to a prompt template to get context-aware prompts.

````python
from langchain_core.runnables import RunnablePassthrough
from langchain_core.prompts import ChatPromptTemplate
from pprint import pprint

conversation_template = ChatPromptTemplate.from_messages([
    ("system", """Read the following passages and answer the question using only the information they contain.
         If the requested information is not in the passages, or the question cannot be answered from them, just say that you can't answer the question.
         ```{context}```
         """),
    ("human", """Question: ```{question}```
    Answer: """)
])

# The dict runs both branches on the same input question:
# "context" receives the retrieved chunks, "question" receives the question unchanged
chain = {"context": retriever, "question": RunnablePassthrough()} | conversation_template

# Inspect the filled-in prompt messages (no LLM call yet)
for message in chain.invoke("What is all I need?").messages:
  pprint((str(type(message)), message.content))
````
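In this chain, LangChain turns the leading dictionary into a `RunnableParallel` that feeds the same input to each of its values, and `RunnablePassthrough()` returns its input unchanged. Below is a plain-Python analogue of that data flow; the stub functions standing in for the real retriever and template are hypothetical, and the branches run sequentially here rather than in parallel.

```python
def fake_retriever(question: str) -> list[str]:
    # Stand-in for `retriever`: pretend these chunks were retrieved
    return ["Attention is all you need.", "The Transformer uses self-attention."]


def passthrough(x):
    # Stand-in for RunnablePassthrough(): return the input unchanged
    return x


def fill_template(inputs: dict) -> list[tuple[str, str]]:
    # Stand-in for conversation_template: substitute the variables into the messages
    return [
        ("system", f"Use only these documents: {inputs['context']}"),
        ("human", f"Question: {inputs['question']}\nAnswer: "),
    ]


def run_chain(question: str):
    # Step 1: every branch of the dict receives the same input
    branches = {"context": fake_retriever, "question": passthrough}
    inputs = {name: fn(question) for name, fn in branches.items()}
    # Step 2: the resulting dict is piped into the template
    return fill_template(inputs)


for role, content in run_chain("What is all I need?"):
    print(role, "->", content)
```

Notice that the context is rendered as the `repr` of a Python list. With real `Document` objects, that `repr` includes their metadata, which is the clutter cleaned up next.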

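Notice that the retrieved documents are inserted into the prompt as raw Python objects (their `repr`), including a lot of irrelevant metadata.
A small helper function that keeps only the section name and the text formats them better.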

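```python
from langchain_core.runnables import RunnableLambda

def format_message(msgs):
  # Render each retrieved chunk as "Section: <header>\nContents: <text>".
  # Use the most specific header available; chunks without headers get "Unknown".
  return '\n\n'.join(
      f"Section: {msg.metadata.get('Header 2') or msg.metadata.get('Header 1') or 'Unknown'}\nContents: {msg.page_content}"
      for msg in msgs
  )

# RunnableLambda wraps the plain function so it can be piped after the retriever
chain = {"context": retriever | RunnableLambda(format_message), "question": RunnablePassthrough()} | conversation_template

for message in chain.invoke("What is all I need?").messages:
  pprint((str(type(message)), message.content))
```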

We can also filter chunks inside the helper, for example dropping those from the `References` section:

```python
def format_message(msgs):
  # Same formatting as before, but skip chunks that come from the "References" section
  formatted = []
  for msg in msgs:
    section = msg.metadata.get('Header 2') or msg.metadata.get('Header 1') or 'Unknown'
    if section != "References":
      formatted.append(f"Section: {section}\nContents: {msg.page_content}")
  return '\n\n'.join(formatted)

chain = {"context": retriever | RunnableLambda(format_message), "question": RunnablePassthrough()} | conversation_template

for message in chain.invoke("What is this paper about?").messages:
  pprint((str(type(message)), message.content))
```
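Filtering after retrieval has a side effect: if some of the `k` retrieved chunks come from `References`, the LLM sees fewer than `k` chunks. One common workaround, sketched below in plain Python under our own assumptions (the `toy_search` stub and `retrieve_filtered` helper are hypothetical, and the chunk texts are placeholders), is to over-fetch candidates and trim back to `k` after filtering.

```python
def retrieve_filtered(search, query, k=4, excluded=("References",), fetch_factor=3):
    """Fetch k * fetch_factor candidates, drop excluded sections, keep the best k.

    `search(query, n)` is assumed to return up to n (section, text) pairs, best first.
    """
    candidates = search(query, k * fetch_factor)
    kept = [(section, text) for section, text in candidates if section not in excluded]
    return kept[:k]


# Placeholder ranked results standing in for a vector-store search
RANKED = [
    ("References", "ref chunk 1"),
    ("Model Architecture", "arch chunk"),
    ("References", "ref chunk 2"),
    ("Attention", "attention chunk"),
    ("Training", "training chunk"),
]


def toy_search(query, n):
    return RANKED[:n]


print(retrieve_filtered(toy_search, "attention", k=2))                  # two usable chunks
print(retrieve_filtered(toy_search, "attention", k=2, fetch_factor=1))  # no over-fetch: only one survives
```

In this notebook, the same idea would mean raising `k` in `search_kwargs` and keeping only the first few chunks that survive the filter in `format_message`.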

Now, pass the prompt to the LLM and see how it responds.

```python
from langchain_core.output_parsers import StrOutputParser

# Full RAG chain: retrieve -> format -> fill prompt -> LLM -> extract the reply text
chain = {"context": retriever | RunnableLambda(format_message), "question": RunnablePassthrough()} | conversation_template | llm | StrOutputParser()

print(chain.invoke("What is this paper about?"))
```

We can also customize the prompt to get additional details, such as citing the section the answer was drawn from.

````python
conversation_template = ChatPromptTemplate.from_messages([
    ("system", """Read the following passages and answer the question using only the information they contain.
         Do not refer to the passages themselves; just write the answer. After the answer, cite the section you read it from.
         If the requested information is not in the passages, or the question cannot be answered from them, just say that you can't answer the question.
         ```{context}```
         """),
    ("human", """Question: ```{question}```
    Answer: """)
])

chain = {"context": retriever | RunnableLambda(format_message), "question": RunnablePassthrough()} | conversation_template | llm | StrOutputParser()

print(chain.invoke("What is this paper about?"))
````

```python
print(chain.invoke("Give me the summary of this work"))
```

```python
print(chain.invoke("What are the elements of the transformer architecture?"))
```

## Optional: Build a Gradio interface for our RAG pipeline

```python
# Load Gradio's notebook extension, which provides the %%blocks cell magic
%load_ext gradio
```

```python
%%blocks

import gradio as gr

# Gradio needs a plain function that maps the user's input to an output;
# this one simply runs our LangChain RAG chain
def rag(question):
  return chain.invoke(question)

demo = gr.Interface(
    fn=rag,             # wire the function to the interface
    inputs=["text"],    # one text box for the question
    outputs=["text"],   # one text box for the answer
)
# Outside the %%blocks magic, call demo.launch() to start the app
```

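There you go!
Each RAG application has its own details and nuances, and there is no "one size fits all" solution. Instead, the results depend heavily on design choices such as how documents are parsed and split, which embedding model and LLM are used, and how the prompt is written.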

This notebook provides a good starting point to implement your own RAG pipeline.