# Enhancing RAG with Contextual Retrieval

We will use an LLM to generate, for each chunk of a document, a short contextual sentence that improves the chunk's retrieval accuracy and its use in hybrid search.

* Generate the context sentence.
* Enrich the chunk embedding vectors with the context.
* Create a topic database to be used in hybrid search.
* Perform hybrid search to improve retrieval results.

### Visual improvements

We will use the [rich library](https://github.com/Textualize/rich) to make the output more readable, and suppress warning messages.

In [2]:
from rich.pretty import pprint
from rich.theme import Theme
from rich.console import Console
from rich.panel import Panel
from rich.text import Text

# Colour overrides for rich's repr highlighting, keyed by style name.
theme_styles = {
    "repr.own": "bright_yellow",            # Class names
    "repr.tag_name": "bright_yellow",       # Adjust tag names which might still be purple
    "repr.call": "bright_yellow",           # Function calls and other symbols
    "repr.str": "bright_green",             # String representation
    "repr.number": "bright_red",            # Numbers
    "repr.attrib_name": "bright_yellow",    # Attribute names
    "repr.attrib_value": "bright_blue"      # Attribute values
}
custom_theme = Theme(theme_styles)

# Console that applies the custom theme; the later cells print through it.
console = Console(theme=custom_theme)

In [3]:
import warnings

# Suppress all warnings for the rest of the session so they do not clutter the output
warnings.filterwarnings('ignore')

## Loading a complex dataset of documents

We will load a complex dataset of scientific documents from arXiv. Applying naive chunking to such documents gives poor results in RAG applications.

In [4]:
from datasets import load_dataset

# Load the "train" split of the jamescalam/ai-arxiv2 dataset and print its summary.
dataset = load_dataset("jamescalam/ai-arxiv2", split="train")
console.print(dataset)

In [5]:
import os
from semantic_router.encoders import OpenAIEncoder

# Fail fast with a clear message when the key is missing. Writing the result of
# os.getenv() straight back into os.environ is a no-op when the variable is set
# and raises an opaque TypeError when it is not (os.environ values must be str).
if not os.getenv("OPENAI_API_KEY"):
    raise EnvironmentError(
        "OPENAI_API_KEY is not set; export it (or load it from a .env file) "
        "before creating the encoder."
    )

# OpenAI embedding model that is handed to the chunker in the next cell.
encoder = OpenAIEncoder(name="text-embedding-3-small")

In [6]:
from semantic_chunkers import StatisticalChunker
import logging

# Disable log records at CRITICAL severity and below, i.e. all standard logging output.
logging.disable(logging.CRITICAL)

# Chunker driven by the encoder defined above.
# The two token limits presumably bound the size of each produced split — confirm
# against the semantic_chunkers documentation.
chunker = StatisticalChunker(
    encoder=encoder,
    min_split_tokens=100,
    max_split_tokens=500,
)

In [7]:
# Chunk only the first paper. Index the row first (dataset[0]) so a single
# record is read instead of materialising the whole "content" column.
chunks_0 = chunker(docs=[dataset[0]["content"]])


In [11]:
# Join the splits of the first chunk (of the single document chunked above)
# into one string and print it.
first_chunk = ' '.join(chunks_0[0][0].splits)
console.print(first_chunk)

In [18]:
from dotenv import load_dotenv

# Load variables from a .env file into the process environment.
# NOTE(review): this runs after OPENAI_API_KEY was already read earlier in the
# notebook — confirm whether it should be moved before that cell.
load_dotenv()

True

In [19]:
import anthropic

# Anthropic API client used by situate_context(); the key is read from the
# ANTHROPIC_API_KEY environment variable.
client = anthropic.Anthropic(
    # This is the default and can be omitted
    api_key=os.getenv("ANTHROPIC_API_KEY"),
)


In [20]:
# Template for the first content part of the request: the whole document,
# wrapped in <document> tags. Filled in with str.format(doc_content=...).
DOCUMENT_CONTEXT_PROMPT = """
<document>
{doc_content}
</document>
"""

# Template for the second content part: the chunk to situate, followed by the
# instruction to answer with only a short context. Filled in with
# str.format(chunk_content=...).
CHUNK_CONTEXT_PROMPT = """
Here is the chunk we want to situate within the whole document
<chunk>
{chunk_content}
</chunk>

Please give a short succinct context to situate this chunk within the overall document for the purposes of improving search retrieval of the chunk.
Answer only with the succinct context and nothing else.
"""

def situate_context(doc: str, chunk: str):
    """Ask Claude for a short context that situates *chunk* within *doc*.

    The request is a single user message with two text parts: the full
    document (marked with an ephemeral ``cache_control`` so it can be served
    from the prompt cache on later calls) and the chunk plus the instruction.

    Args:
        doc: Full text of the source document.
        chunk: Text of the chunk to situate within ``doc``.

    Returns:
        The raw response object from the Anthropic client, not a plain
        string. Callers read the generated context from
        ``response.content[0].text``.
    """
    # Document part: identical for every chunk of the same document, so it is
    # the part marked for prompt caching.
    document_part = {
        "type": "text",
        "text": DOCUMENT_CONTEXT_PROMPT.format(doc_content=doc),
        "cache_control": {"type": "ephemeral"},
    }
    # Chunk part: differs on every call.
    chunk_part = {
        "type": "text",
        "text": CHUNK_CONTEXT_PROMPT.format(chunk_content=chunk),
    }

    response = client.beta.prompt_caching.messages.create(
        model="claude-3-haiku-20240307",
        max_tokens=1024,
        temperature=0.0,
        messages=[{"role": "user", "content": [document_part, chunk_part]}],
        extra_headers={"anthropic-beta": "prompt-caching-2024-07-31"},
    )
    return response

In [23]:
# Generate the situating context for the first chunk. dataset[0]["content"]
# reads only the first row rather than the whole "content" column.
chunk_context = situate_context(dataset[0]["content"], first_chunk)

In [24]:
# Print the response object returned by situate_context() for the first chunk.
console.print(chunk_context)

In [25]:
# Join the splits of the second chunk into one string.
second_chunk = ' '.join(chunks_0[0][1].splits)

In [26]:
# Generate the situating context for the second chunk. dataset[0]["content"]
# reads only the first row rather than the whole "content" column.
second_chunk_context = situate_context(dataset[0]["content"], second_chunk)

In [27]:
# Print the response object returned by situate_context() for the second chunk.
console.print(second_chunk_context)

In [61]:
from tqdm import tqdm

# Read the first paper once and pull out the fields shared by all its chunks.
paper = dataset[0]
arxiv_id = paper["id"]
refs = list(paper["references"].values())
doc_text = paper["content"]
title = paper["title"]

# One record per chunk: the chunk text followed by its LLM-generated context.
corpus_json = []
progress = tqdm(
    enumerate(chunks_0[0]),
    total=len(chunks_0[0]),
    desc="Processing chunks",
)
for i, chunk in progress:
    chunk_text = ' '.join(chunk.splits)
    # situate_context() returns the raw response; the text is in its first content part.
    contextualized_text = situate_context(doc_text, chunk_text).content[0].text
    record = {
        "id": i,
        "text": f"{chunk_text}\n\n{contextualized_text}",
        "metadata": {
            "title": title,
            "arxiv_id": arxiv_id,
            "references": refs,
        },
    }
    corpus_json.append(record)

Processing chunks: 100%|██████████| 46/46 [10:40<00:00, 13.91s/it]


In [62]:
# Print the first two corpus records.
console.print(corpus_json[:2])

## Hybrid Search

We will use a BM25-backed index to complement the semantic search performed with the vector database.

In [35]:
import bm25s

In [63]:
# Texts to index: one contextualized chunk per corpus record.
corpus_text = [doc["text"] for doc in corpus_json]

# Tokenize the corpus and only keep the ids (faster and saves memory)
corpus_tokens = bm25s.tokenize(corpus_text, stopwords="en")

# Create the BM25 retriever and attach corpus_json to it so that retrieval
# can return the full records
retriever = bm25s.BM25(corpus=corpus_json)
# Now, index the corpus_tokens (the corpus_json is not used yet)
retriever.index(corpus_tokens)




Split strings:   0%|          | 0/46 [00:00<?, ?it/s]

BM25S Count Tokens:   0%|          | 0/46 [00:00<?, ?it/s]

BM25S Compute Scores:   0%|          | 0/46 [00:00<?, ?it/s]

In [88]:
# Run the lexical (BM25) side of the search for a sample question.
query = "What is context size of Mixtral?"
query_tokens = bm25s.tokenize(query)

# Fetch the ten best-scoring corpus entries for the tokenized query.
results, scores = retriever.retrieve(query_tokens, k=10)

# Both arrays hold one row per query; walk the single row in rank order.
for i, (doc, score) in enumerate(zip(results[0], scores[0])):
    console.print(f"Rank {i+1} (score: {score:.2f}): {doc}")

Split strings:   0%|          | 0/1 [00:00<?, ?it/s]

BM25S Retrieve:   0%|          | 0/1 [00:00<?, ?it/s]

In [65]:
from qdrant_client import QdrantClient
from qdrant_client.http import models
from sentence_transformers import SentenceTransformer

# Qdrant client created with the ":memory:" location.
qdrant_client = QdrantClient(
    ":memory:"
) 

# Create the embedding encoder
# NOTE: this rebinds the name `encoder`, which earlier referred to the
# OpenAIEncoder passed to the chunker.
encoder = SentenceTransformer('all-MiniLM-L6-v2') # Model to create embeddings

In [66]:
collection_name = "hybrid_search"

# Vector size is defined by the embedding model in use; vectors are compared
# with cosine distance.
vector_params = models.VectorParams(
    size=encoder.get_sentence_embedding_dimension(),
    distance=models.Distance.COSINE,
)

# (Re)create the collection and print the result reported by the client.
first_collection = qdrant_client.recreate_collection(
    collection_name=collection_name,
    vectors_config=vector_params,
)
print(first_collection)

True


In [67]:
# Embed the text of every corpus record and upload it as a point whose payload
# is the full record.
qdrant_client.upload_points(
    collection_name=collection_name,
    points=[
        models.PointStruct(
            id=idx,
            vector=encoder.encode(doc["text"]).tolist(),
            payload=doc
        ) for idx, doc in enumerate(corpus_json) # corpus_json holds one record per contextualized chunk
    ]
)

In [68]:
# Embed the query with the same encoder used for the corpus points.
query_vector = encoder.encode(query).tolist()

In [89]:
# Dense (vector) search: fetch up to 10 points for the query vector.
hits = qdrant_client.search(
    collection_name=collection_name,
    query_vector=query_vector,
    limit=10
)

In [90]:
# Display the raw search hits.
hits

[ScoredPoint(id=15, version=0, score=0.6180975806351034, payload={'id': 15, 'text': "3.8 â Mixtral_8x7B 3.5 32 > $3.0 i] 228 fos a 2.0 0 5k 10k 15k 20k 25k 30k Context length Passkey Performance ry 3.8 â Mixtral_8x7B 3.5 0.8 32 > 0.6 $3.0 i] 228 04 fos 0.2 a 2.0 0.0 OK 4K 8K 12K 16K 20K 24K 28K 0 5k 10k 15k 20k 25k 30k Seq Len Context length Figure 4: Long range performance of Mixtral. (Left) Mixtral has 100% retrieval accuracy of the Passkey task regardless of the location of the passkey and length of the input sequence. (Right) The perplexity of Mixtral on the proof-pile dataset decreases monotonically as the context length increases.\n\nThe chunk discusses the long-range performance of the Mixtral model, demonstrating its ability to retrieve a passkey regardless of its location in a long input sequence, and showing that the model's perplexity on the proof-pile dataset decreases as the context length increases.", 'metadata': {'title': 'Mixtral of Experts', 'arxiv_id': '2401.04088', '

In [91]:
# Merge the dense (Qdrant) and sparse (BM25) results into one record per chunk id.
# Chunks returned by only one retriever are kept too; the score they lack is
# simply absent and is read as 0 by the `.get(..., 0)` lookups in the scoring step.
corpus_text_by_id = {doc["id"]: doc["text"] for doc in corpus_json}
merged_by_id = {}

# Dense hits: every hit becomes a candidate.
for hit in hits:
    doc_id = hit.payload["id"]
    merged_by_id[doc_id] = {
        "id": doc_id,
        "text": corpus_text_by_id[doc_id],
        "dense_score": hit.score,
    }

# Sparse hits: attach the BM25 score to an existing candidate, or add a new one.
for result, sparse_score in zip(results[0], scores[0]):
    doc_id = result["id"]
    if doc_id not in merged_by_id:
        # A non-positive BM25 score carries no lexical evidence, so it is not
        # enough on its own to introduce a new candidate.
        if sparse_score <= 0:
            continue
        merged_by_id[doc_id] = {"id": doc_id, "text": corpus_text_by_id[doc_id]}
    merged_by_id[doc_id]["sparse_score"] = sparse_score

# Dense hits come first (in hit order), followed by sparse-only candidates.
documents_with_scores = list(merged_by_id.values())




In [92]:
# Print the merged candidate records with their scores.
console.print(documents_with_scores)

In [96]:
import numpy as np

def _min_max_normalize(values):
    """Scale *values* linearly into [0, 1].

    Returns an empty array for empty input and all zeros when every value is
    identical, instead of dividing by zero and producing NaNs.
    """
    values = np.asarray(values, dtype=float)
    if values.size == 0:
        return values
    lowest = values.min()
    span = values.max() - lowest
    if span == 0:
        return np.zeros_like(values)
    return (values - lowest) / span


# Normalize the two types of scores; a candidate missing one score counts as 0.
dense_scores = np.array([doc.get("dense_score", 0) for doc in documents_with_scores])
sparse_scores = np.array([doc.get("sparse_score", 0) for doc in documents_with_scores])

dense_scores_normalized = _min_max_normalize(dense_scores)
sparse_scores_normalized = _min_max_normalize(sparse_scores)

# Calculate a weighted score with alpha of 0.2 to the sparse score
alpha = 0.2
weighted_scores = (1 - alpha) * dense_scores_normalized + alpha * sparse_scores_normalized

# Pick up the top 3 documents with the weighted score, as (document, score) pairs
top_docs = sorted(
    zip(documents_with_scores, weighted_scores),
    key=lambda pair: pair[1],
    reverse=True,
)[:3]



In [95]:
# Print the top-ranked (document, weighted score) pairs.
console.print(top_docs)

In [97]:
# Keep only the text of each top-ranked document; the weighted score is dropped.
search_results = [document["text"] for document, _weighted_score in top_docs]

In [98]:
# Answer the query with an OpenAI chat model, grounded in the retrieved chunks.
from openai import OpenAI

# Use a dedicated name: rebinding `client` would replace the Anthropic client
# that situate_context() relies on.
openai_client = OpenAI()

# Pass the retrieved chunks as context inside the user turn. Sending them as a
# trailing "assistant" message presents them to the model as its own previous
# reply, and str(list) would add Python list-repr noise to the prompt.
context_block = "\n\n".join(search_results)
completion = openai_client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[
        {"role": "system", "content": "You are chatbot, an research expert. Your top priority is to help guide users to understand reserach papers."},
        {"role": "user", "content": f"Context:\n{context_block}\n\nQuestion: {query}"},
    ]
)

# Wrap the answer in a rich Text object for display.
response_text = Text(completion.choices[0].message.content)

In [99]:
# Display the model's answer.
response_text