# Enhancing RAG with Contextual Retrieval

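We will use an LLM to generate, for each chunk of a document, a short contextual sentence that situates the chunk within the whole document. Adding this sentence to the chunk text improves retrieval accuracy, and the enriched chunks can then be used in hybrid search.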

* [Load a complex dataset of documents](#loading-a-complex-dataset-of-documents)
* [Split the documents into chunks](#split-the-documents-into-chunks)
* [Generate the context sentence](#generate-the-context-sentence)
* [Enrich the chunk embedding vectors with the context](#enrich-the-chunk-embedding-vectors-with-the-context)
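The following toy illustration is not part of the pipeline. It shows why the added sentence can help retrieval. A simple bag-of-words overlap score stands in for the lexical half of a hybrid search. It is applied to a chunk with and without a generated context. The chunk, context, and query strings (including the paper name "ExampleNet") are invented for the example.

```python
import re

def tokens(text: str) -> set[str]:
    """Lowercased alphanumeric word tokens (a crude stand-in for a real tokenizer)."""
    return set(re.findall(r"[a-z0-9]+", text.lower()))

def overlap_score(query: str, text: str) -> float:
    """Fraction of the query's terms that also appear in the text."""
    q = tokens(query)
    return len(q & tokens(text)) / len(q) if q else 0.0

# Hypothetical chunk that never names the paper it comes from
chunk = "The model outperforms the baseline on all reasoning benchmarks."
# Hypothetical context sentence of the kind the LLM is asked to generate
context = "This chunk is from the ExampleNet paper and reports its benchmark results."
query = "ExampleNet benchmark results"

print(overlap_score(query, chunk))                    # 0.0: no query term matches
print(overlap_score(query, f"{chunk}\n\n{context}"))  # 1.0: every query term matches
```

Real hybrid search pairs stronger lexical scoring, such as BM25, with embedding similarity. The mechanism is the same: terms that appear only in the generated context become matchable.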

### Visual improvements

We will use the [rich library](https://github.com/Textualize/rich) to make the output more readable, and we will suppress warning messages.

```python
from rich.console import Console
from rich_theme_manager import ThemeManager
import pathlib

# Load the saved "dark" theme from the local themes/ directory
theme_dir = pathlib.Path("themes")
theme_manager = ThemeManager(theme_dir=theme_dir)
dark = theme_manager.get("dark")

# Create a console with the dark theme
console = Console(theme=dark)
```

```python
import warnings

# Suppress warnings
warnings.filterwarnings('ignore')
```

## Loading a complex dataset of documents

We will load a complex dataset of scientific papers from arXiv. Naive chunking of such documents (for example, fixed-size splits) tends to give poor results in RAG applications.

```python
from datasets import load_dataset

dataset = load_dataset("jamescalam/ai-arxiv2", split="train")
console.print(dataset)
```

## Split the documents into chunks

We will use the statistical chunker (`StatisticalChunker` from the `semantic-chunkers` library) from a previous notebook.
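The core idea behind statistical (semantic) chunking can be sketched in a few lines. Embed consecutive sentences, then start a new chunk wherever the cosine similarity between neighbours drops below a threshold. This is a simplified illustration, not the library's implementation. It uses a fixed, hypothetical `threshold` and hand-made 2-D vectors. `StatisticalChunker` works on real embeddings from the encoder and also respects the `min_split_tokens`/`max_split_tokens` bounds set below.

```python
import math

def cosine(a: list[float], b: list[float]) -> float:
    """Cosine similarity between two vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.hypot(*a) * math.hypot(*b))

def split_on_similarity_drop(sentences: list[str], vectors: list[list[float]],
                             threshold: float = 0.8) -> list[list[str]]:
    """Group consecutive sentences, starting a new chunk whenever the similarity
    between a sentence and the previous one falls below `threshold`."""
    if not sentences:
        return []
    chunks = [[sentences[0]]]
    for prev_vec, vec, sentence in zip(vectors, vectors[1:], sentences[1:]):
        if cosine(prev_vec, vec) < threshold:
            chunks.append([sentence])    # topic shift: open a new chunk
        else:
            chunks[-1].append(sentence)  # same topic: extend the current chunk
    return chunks

# Toy 2-D "embeddings": the first two sentences point one way, the last two another
sentences = [
    "Attention lets each token weigh every other token.",
    "The attention weights are normalized with a softmax.",
    "We trained the model on eight GPUs.",
    "Training took two days.",
]
vectors = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]]

for i, group in enumerate(split_on_similarity_drop(sentences, vectors)):
    print(i, group)
```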

```python
from dotenv import load_dotenv

# Load API keys (such as OPENAI_API_KEY and ANTHROPIC_API_KEY) from a local .env file
load_dotenv()
```

```python
from semantic_router.encoders import OpenAIEncoder

# Embedding model the chunker uses to compare neighbouring splits
encoder = OpenAIEncoder(name="text-embedding-3-small")
```

```python
import logging

from semantic_chunkers import StatisticalChunker

# Silence the chunker's logging output
logging.disable(logging.CRITICAL)

chunker = StatisticalChunker(
    encoder=encoder,
    min_split_tokens=100,
    max_split_tokens=500,
)
```

```python
# Chunk the first document; the result holds one list of chunks per input document
chunks_0 = chunker(docs=[dataset["content"][0]])
```


```python
from rich.text import Text
from rich.panel import Panel

# Each chunk holds a list of text splits; join them to get the chunk text
chunk_0_0 = ' '.join(chunks_0[0][0].splits)

content = Text(chunk_0_0)
console.print(Panel(content, title="Chunk 0", expand=False, border_style="bold"))
```

## Generate the context sentence

We will use Anthropic's Claude to generate the context. Claude is a strong summarizer. Anthropic's [Prompt Caching](https://www.anthropic.com/news/prompt-caching) is also a good fit for this task: every chunk of a document is contextualized against the same full document. The document can therefore be cached once and reused across all of that document's calls, which reduces cost and latency.
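A back-of-the-envelope calculation shows why caching matters here. According to Anthropic's announcement linked above, writing a prompt prefix to the cache costs 25% more than the base input-token price. Reading it back costs 10% of the base price. The sketch below compares the document-token cost of contextualizing N chunks with and without caching. It assumes the cache stays warm between calls (entries have a five-minute lifetime that is refreshed on each hit). It ignores the small per-chunk prompt and the output tokens. The 10,000-token document length is a made-up example value.

```python
def document_token_cost(doc_tokens: int, n_chunks: int,
                        write_mult: float = 1.25,
                        read_mult: float = 0.10) -> tuple[float, float]:
    """Cost of sending the full document once per chunk, in units of
    base-price input tokens: (without caching, with caching)."""
    if n_chunks <= 0:
        return 0.0, 0.0
    without_cache = float(n_chunks * doc_tokens)
    # The first call writes the document to the cache (25% surcharge);
    # every later call reads it back at 10% of the base price.
    with_cache = doc_tokens * write_mult + (n_chunks - 1) * doc_tokens * read_mult
    return without_cache, with_cache

# 46 chunks, the number the first paper is split into later in this notebook
without_cache, with_cache = document_token_cost(doc_tokens=10_000, n_chunks=46)
print(f"{with_cache / without_cache:.1%} of the uncached cost")  # 12.5% of the uncached cost
```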


```python
import anthropic

# Reads ANTHROPIC_API_KEY from the environment
client = anthropic.Anthropic()
```


```python
DOCUMENT_CONTEXT_PROMPT = """
<document>
{doc_content}
</document>
"""

CHUNK_CONTEXT_PROMPT = """
Here is the chunk we want to situate within the whole document
<chunk>
{chunk_content}
</chunk>

Please give a short succinct context to situate this chunk within the overall document for the purposes of improving search retrieval of the chunk.
Answer only with the succinct context and nothing else.
"""

def situate_context(doc: str, chunk: str):
    """Ask Claude for a short context that situates `chunk` within `doc`.

    Returns the full response object: the generated text is in
    `response.content[0].text` and cache statistics are in `response.usage`.
    """
    response = client.beta.prompt_caching.messages.create(
        model="claude-3-haiku-20240307",
        max_tokens=1024,
        temperature=0.0,
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": DOCUMENT_CONTEXT_PROMPT.format(doc_content=doc),
                        # Cache the full document: every chunk of the same document reuses this prefix
                        "cache_control": {"type": "ephemeral"},
                    },
                    {
                        "type": "text",
                        "text": CHUNK_CONTEXT_PROMPT.format(chunk_content=chunk),
                    },
                ],
            }
        ],
        # Beta endpoint and header from when prompt caching was in beta; newer SDK versions,
        # where caching is generally available, accept the same cache_control blocks
        # through client.messages.create without this header
        extra_headers={"anthropic-beta": "prompt-caching-2024-07-31"},
    )
    return response
```

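```python
# Generate the context for the first chunk and print the full response, including usage
chunk_context = situate_context(dataset["content"][0], chunk_0_0)
console.print(chunk_context)
```

```python
# Repeat for a later chunk of the same document; within the cache lifetime,
# this call can reuse the cached document
chunk_0_5 = ' '.join(chunks_0[0][5].splits)
second_chunk_context = situate_context(dataset["content"][0], chunk_0_5)
console.print(second_chunk_context)
```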
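Printing the whole response object, rather than only its text, is useful here because its `usage` field reports how the cache was used. `cache_creation_input_tokens` is non-zero when a call writes the document to the cache. `cache_read_input_tokens` is non-zero when a later call reuses it. The helper below, `cache_status`, is a hypothetical convenience function, not part of the Anthropic SDK, that classifies such a `usage` object. It is exercised offline with `SimpleNamespace` stand-ins whose token counts are made up.

```python
from types import SimpleNamespace

def cache_status(usage) -> str:
    """Classify a Messages API `usage` object as 'write', 'hit' or 'miss'."""
    # The cache fields may be missing or None, so default them to 0
    written = getattr(usage, "cache_creation_input_tokens", 0) or 0
    read = getattr(usage, "cache_read_input_tokens", 0) or 0
    if read > 0:
        return "hit"
    if written > 0:
        return "write"
    return "miss"

# Offline stand-ins with made-up token counts
first = SimpleNamespace(input_tokens=60, cache_creation_input_tokens=9_000, cache_read_input_tokens=0)
second = SimpleNamespace(input_tokens=55, cache_creation_input_tokens=0, cache_read_input_tokens=9_000)
print(cache_status(first), cache_status(second))  # write hit
```

In the notebook you would call `cache_status(chunk_context.usage)` and `cache_status(second_chunk_context.usage)`. If both report `miss`, there are two likely causes: the document may be shorter than the model's minimum cacheable prompt length, or more than five minutes may have passed between the calls.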

## Enrich the chunk embedding vectors with the context

### Concatenate the generated context to the chunk text

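We iterate over all the chunks of the first paper and generate a context for each one, calling the API once per chunk. The running time therefore grows with the number of chunks: in the run below, 46 chunks took almost 27 minutes.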

```python
from tqdm import tqdm

# Document-level fields shared by every chunk of the first paper
arxiv_id = dataset[0]["id"]
refs = list(dataset[0]["references"].values())
doc_text = dataset[0]["content"]
title = dataset[0]["title"]

corpus_json = []
for i, chunk in tqdm(enumerate(chunks_0[0]), total=len(chunks_0[0]), desc="Processing chunks"):
    chunk_text = ' '.join(chunk.splits)
    # Generated context for this chunk (text of the first content block)
    contextualized_text = situate_context(doc_text, chunk_text).content[0].text
    corpus_json.append({
        "id": i,
        # The chunk text followed by its generated context
        "text": f"{chunk_text}\n\n{contextualized_text}",
        "metadata": {
            "title": title,
            "arxiv_id": arxiv_id,
            "references": refs,
        },
    })
```

```text
Processing chunks: 100%|██████████| 46/46 [26:50<00:00, 35.00s/it]
```
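At roughly 35 seconds per chunk, the sequential loop is the slow part of this notebook. One option is to send the requests concurrently. The sketch below assumes a hypothetical `situate_fn(doc, chunk) -> str` callable, for example `lambda d, c: situate_context(d, c).content[0].text`. It makes the first call alone so the document is written to the cache, then fans the remaining chunks out over a thread pool so those calls can read the cached document. Parallel requests sent before the cache is written cannot benefit from it. The `max_workers` value is an arbitrary assumption and should stay within your API rate limits. The `fake_situate` stub stands in for the real API call.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Callable

def contextualize_all(doc: str, chunks: list[str],
                      situate_fn: Callable[[str, str], str],
                      max_workers: int = 4) -> list[str]:
    """Return one context string per chunk, in the original chunk order."""
    if not chunks:
        return []
    # Sequential first call: writes the shared document prefix to the cache
    first = situate_fn(doc, chunks[0])
    # Remaining calls run concurrently; executor.map yields results in input order
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        rest = list(pool.map(lambda c: situate_fn(doc, c), chunks[1:]))
    return [first, *rest]

# Offline stub standing in for the real API call
def fake_situate(doc: str, chunk: str) -> str:
    return f"context for {chunk!r}"

contexts = contextualize_all("full document", ["a", "b", "c"], fake_situate)
print(contexts)
```

The returned contexts line up with the chunk texts, so they can be zipped together to build `corpus_json` with the same record layout as the loop above.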


```python
console.print(corpus_json[:2])
```

### Saving `corpus_json` to a file

We will load it again in the next notebook.

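```python
import json
import os

# Make sure the output directory exists before writing
os.makedirs("data", exist_ok=True)

with open('data/corpus.json', 'w', encoding='utf-8') as f:
    json.dump(corpus_json, f)
```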
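In the next notebook the corpus can be read back with `json.load`. Below is a small sketch of a loader that also checks each record has the fields written above (`id`, `text`, `metadata`). The name `load_corpus` is hypothetical. The example round-trips a one-record sample corpus through a temporary directory, so it does not touch `data/`.

```python
import json
import tempfile
from pathlib import Path

REQUIRED_KEYS = {"id", "text", "metadata"}

def load_corpus(path) -> list[dict]:
    """Load a corpus JSON file and check that every record has the expected keys."""
    with open(path, encoding="utf-8") as f:
        corpus = json.load(f)
    for record in corpus:
        missing = REQUIRED_KEYS - set(record)
        if missing:
            raise ValueError(f"record {record.get('id')!r} is missing {sorted(missing)}")
    return corpus

# Round-trip a one-record sample corpus through a temporary file
with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "corpus.json"
    sample = [{"id": 0, "text": "chunk\n\ncontext", "metadata": {"title": "example"}}]
    path.write_text(json.dumps(sample), encoding="utf-8")
    loaded = load_corpus(path)

print(loaded == sample)  # True
```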

