# Enhancing RAG with Contextual Retrieval

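We will use an LLM to generate, for each chunk of each document, a short context sentence that situates the chunk within its source document. Adding this context to the chunk improves retrieval accuracy and makes the chunk more useful in hybrid search.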

* Generate the context sentence.
* Enrich the chunk embedding vectors with the context.
* Create a topic database to be used in hybrid search.
* Perform hybrid search to improve retrieval results.

## Visual improvements

We will use the [rich](https://github.com/Textualize/rich) library to make the output more readable, and suppress warning messages.

In [2]:
from rich.pretty import pprint
from rich.theme import Theme
from rich.console import Console
from rich.panel import Panel
from rich.text import Text

custom_theme = Theme({
    "repr.own": "bright_yellow",            # Class names
    "repr.tag_name": "bright_yellow",       # Adjust tag names which might still be purple
    "repr.call": "bright_yellow",           # Function calls and other symbols
    "repr.str": "bright_green",             # String representation
    "repr.number": "bright_red",            # Numbers
    "repr.attrib_name": "bright_yellow",    # Attribute names
    "repr.attrib_value": "bright_blue"      # Attribute values
})

# Create a console that applies the theme to everything printed through it
console = Console(theme=custom_theme)

In [3]:
import warnings

# Suppress warnings
warnings.filterwarnings('ignore')

## Loading a complex dataset of documents

We will load a complex dataset of scientific papers from arXiv. Naive chunking of such documents tends to give poor results in RAG applications.
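To make the problem with naive chunking concrete, here is a minimal sketch of fixed-size character chunking. The helper `naive_chunks`, the 40-character window, and the sample sentences are all illustrative assumptions, not part of the dataset or of any library used in this notebook.

```python
def naive_chunks(text, chunk_size=40):
    """Split text into fixed-size character windows, ignoring sentence boundaries."""
    return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]


# Made-up sample text, used only to show where the cuts fall.
sample = (
    "We introduce a sparse model. "
    "Each layer routes tokens to experts. "
    "It outperforms the dense baseline."
)

pieces = naive_chunks(sample)
for piece in pieces:
    print(repr(piece))
```

The windows cut through the middle of sentences, so a chunk can lose the subject it refers to. The statistical chunker used in the next cells splits on semantic boundaries instead.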

In [4]:
from datasets import load_dataset

dataset = load_dataset("jamescalam/ai-arxiv2", split="train")
console.print(dataset)

In [5]:
import os
from semantic_router.encoders import OpenAIEncoder

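# The encoder reads OPENAI_API_KEY from the environment; fail early if it is missing
if not os.getenv("OPENAI_API_KEY"):
    raise EnvironmentError("Set OPENAI_API_KEY before creating the encoder")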

encoder = OpenAIEncoder(name="text-embedding-3-small")

In [6]:
from semantic_chunkers import StatisticalChunker
import logging

logging.disable(logging.CRITICAL)

chunker = StatisticalChunker(
    encoder=encoder,
    min_split_tokens=100,
    max_split_tokens=500,
)

In [7]:
chunks_0 = chunker(docs=[dataset["content"][0]])


In [11]:
first_chunk = ' '.join(chunks_0[0][0].splits)
console.print(first_chunk)

In [18]:
from dotenv import load_dotenv

load_dotenv()

True

In [19]:
import anthropic

client = anthropic.Anthropic(
    # This is the default and can be omitted
    api_key=os.getenv("ANTHROPIC_API_KEY"),
)


In [20]:
DOCUMENT_CONTEXT_PROMPT = """
<document>
{doc_content}
</document>
"""

CHUNK_CONTEXT_PROMPT = """
Here is the chunk we want to situate within the whole document
<chunk>
{chunk_content}
</chunk>

Please give a short succinct context to situate this chunk within the overall document for the purposes of improving search retrieval of the chunk.
Answer only with the succinct context and nothing else.
"""

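def situate_context(doc: str, chunk: str):
    """Ask the model for a short context that situates `chunk` within `doc`.

    Returns the full API response; the generated context is in `response.content[0].text`.
    """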
    response = client.beta.prompt_caching.messages.create(
        model="claude-3-haiku-20240307",
        max_tokens=1024,
        temperature=0.0,
        messages=[
            {
                "role": "user", 
                "content": [
                    {
                        "type": "text",
                        "text": DOCUMENT_CONTEXT_PROMPT.format(doc_content=doc),
                        # Cache the full document so repeated calls for its chunks reuse it
                        "cache_control": {"type": "ephemeral"}
                    },
                    {
                        "type": "text",
                        "text": CHUNK_CONTEXT_PROMPT.format(chunk_content=chunk),
                    }
                ]
            }
        ],
        extra_headers={"anthropic-beta": "prompt-caching-2024-07-31"}
    )
    return response

In [23]:
chunk_context = situate_context(dataset["content"][0], first_chunk)

In [24]:
console.print(chunk_context)

In [25]:
second_chunk = ' '.join(chunks_0[0][1].splits)

In [26]:
second_chunk_context = situate_context(dataset["content"][0], second_chunk)

In [27]:
console.print(second_chunk_context)
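Because both calls send the same document block first, the second call is a candidate for a prompt-cache hit. A minimal sketch for checking this is shown below. It assumes the response's `usage` object exposes `cache_creation_input_tokens` and `cache_read_input_tokens`; the helper `cache_summary` is hypothetical, and the token counts in the stand-in objects are made up for illustration.

```python
from types import SimpleNamespace


def cache_summary(usage):
    """Summarize prompt-cache activity from a response's `usage` object."""
    # Missing or None fields are treated as zero tokens.
    written = getattr(usage, "cache_creation_input_tokens", 0) or 0
    read = getattr(usage, "cache_read_input_tokens", 0) or 0
    return {"cache_written": written, "cache_read": read, "cache_hit": read > 0}


# Stand-in usage objects with made-up token counts, for illustration only.
first_call = SimpleNamespace(
    input_tokens=120, cache_creation_input_tokens=9000, cache_read_input_tokens=0
)
second_call = SimpleNamespace(
    input_tokens=130, cache_creation_input_tokens=0, cache_read_input_tokens=9000
)

print(cache_summary(first_call))
print(cache_summary(second_call))
```

With real responses, the same helper could be applied as `cache_summary(chunk_context.usage)` and `cache_summary(second_chunk_context.usage)`.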

In [42]:
arxiv_id = dataset[0]["id"]
refs = list(dataset[0]["references"].values())
doc_text = dataset[0]["content"]
title = dataset[0]["title"]

corpus_json = []
for i, chunk in enumerate(chunks_0[0]):
    chunk_text = ' '.join(chunk.splits)
    contextualized_text = situate_context(doc_text, chunk_text).content[0].text
    corpus_json.append({
        "text": f"{chunk_text}\n\n{contextualized_text}",
        "metadata" : {
            "title": title,
            "arxiv_id": arxiv_id,
            "references": refs
        }
    })

In [43]:
console.print(corpus_json[:2])
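The `text` field of each entry is what would be embedded for the semantic side of hybrid search. The notebook does not show that step, so the following is only a sketch of the lookup once embeddings exist. Toy three-dimensional vectors stand in for real embeddings, and `cosine` and `top_k` are hypothetical helper names.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0


def top_k(query_vec, doc_vecs, k=2):
    """Return (index, similarity) pairs for the k most similar documents."""
    scored = [(i, cosine(query_vec, vec)) for i, vec in enumerate(doc_vecs)]
    return sorted(scored, key=lambda pair: pair[1], reverse=True)[:k]


# Toy vectors standing in for embeddings of the contextualized chunks.
toy_doc_vecs = [[1.0, 0.0, 0.0], [0.6, 0.8, 0.0], [0.0, 0.0, 1.0]]
toy_query_vec = [0.9, 0.1, 0.0]

nearest = top_k(toy_query_vec, toy_doc_vecs)
print(nearest)
```

In practice a vector database would replace the linear scan, but the ranking criterion is the same.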

## Hybrid Search

We will use a BM25 index to complement semantic search over the vector database.
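For intuition about what BM25 rewards, here is a minimal pure-Python sketch of a textbook BM25 scorer. It is not the `bm25s` implementation. The `k1` and `b` values, the IDF variant, the helper name `bm25_scores`, and the toy corpus are assumptions chosen for illustration.

```python
import math
from collections import Counter


def bm25_scores(query_tokens, corpus_tokens, k1=1.5, b=0.75):
    """Score every tokenized document against the query with a textbook BM25 formula."""
    n_docs = len(corpus_tokens)
    avg_len = sum(len(doc) for doc in corpus_tokens) / n_docs
    # Document frequency: the number of documents each term appears in.
    doc_freq = Counter(term for doc in corpus_tokens for term in set(doc))
    scores = []
    for doc in corpus_tokens:
        term_freq = Counter(doc)
        score = 0.0
        for term in query_tokens:
            if term not in term_freq:
                continue
            # Rare terms get a higher weight; this IDF form is never negative.
            idf = math.log(1 + (n_docs - doc_freq[term] + 0.5) / (doc_freq[term] + 0.5))
            tf = term_freq[term]
            # Term frequency saturates via k1 and is normalized by document length via b.
            length_norm = 1 - b + b * len(doc) / avg_len
            score += idf * tf * (k1 + 1) / (tf + k1 * length_norm)
        scores.append(score)
    return scores


# Made-up toy corpus, already tokenized by whitespace.
toy_corpus = [
    "the model supports a long context window".split(),
    "experts are selected by a router network".split(),
    "the router sends each token to two experts".split(),
]
toy_scores = bm25_scores("context window".split(), toy_corpus)
best = max(range(len(toy_scores)), key=toy_scores.__getitem__)
print(best, toy_scores)
```

Only documents that share exact tokens with the query receive a score, which is why keyword search complements embedding-based search rather than replacing it.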

In [35]:
import bm25s

In [44]:
corpus_text = [doc["text"] for doc in corpus_json]

# Tokenize the corpus and only keep the ids (faster and saves memory)
corpus_tokens = bm25s.tokenize(corpus_text, stopwords="en")

# Create the BM25 retriever and attach your corpus_json to it
retriever = bm25s.BM25(corpus=corpus_json)
# Now, index the corpus_tokens (the corpus_json is not used yet)
retriever.index(corpus_tokens)




In [45]:
# Query the corpus
query = "What is the context size of Mixtral?"
query_tokens = bm25s.tokenize(query)


results, scores = retriever.retrieve(query_tokens, k=2)

for i in range(results.shape[1]):
    doc, score = results[0, i], scores[0, i]
    console.print(f"Rank {i+1} (score: {score:.2f}): {doc}")

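To complete the hybrid search, the BM25 ranking has to be merged with the ranking from the vector search. The notebook does not specify a fusion method; one common option is reciprocal rank fusion (RRF), sketched below. The function name, the constant `k=60`, and the example document ids are assumptions for illustration.

```python
def reciprocal_rank_fusion(rankings, k=60):
    """Fuse several ranked lists of document ids into a single ranking.

    Each document earns 1 / (k + rank) from every list it appears in.
    """
    fused = {}
    for ranking in rankings:
        for rank, doc_id in enumerate(ranking, start=1):
            fused[doc_id] = fused.get(doc_id, 0.0) + 1.0 / (k + rank)
    # Highest fused score first.
    return sorted(fused, key=fused.get, reverse=True)


# Made-up rankings: chunk ids ordered best-first by each retriever.
bm25_ranking = [3, 1, 7]
vector_ranking = [1, 4, 3]

fused_ranking = reciprocal_rank_fusion([bm25_ranking, vector_ranking])
print(fused_ranking)
```

RRF uses only rank positions, so it avoids having to normalize BM25 scores and cosine similarities onto a common scale.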