<a href="https://colab.research.google.com/github/zamanmiraz/DSandML-Notebooks/blob/main/RAG/04_contextual_retrieval.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Clone the companion advanced-rag repo and install its requirements.txt
# (--upgrade also upgrades packages that are already installed).
# NOTE: `%cd` changes the kernel's working directory, so relative paths used later
# in this notebook (e.g. "themes", "data/corpus.json") resolve inside advanced-rag/.
# NOTE(review): `git clone` fails if advanced-rag/ already exists, so this cell is
# not safe to re-run as-is.
!git clone https://github.com/guyernest/advanced-rag.git
%cd advanced-rag
!pip install --upgrade -r requirements.txt

In [None]:
# Pin torchvision and install/upgrade the legacy `google-generativeai` SDK
# (used below for embeddings via `genai.embed_content`).
# NOTE(review): the exact torchvision pin presumably matches the torch version pulled
# in by requirements.txt above — confirm before bumping either.
# NOTE(review): a later cell imports `google.genai` (the separate `google-genai`
# package), which is not installed in this cell; presumably it comes from
# requirements.txt or the Colab image — verify.
!pip install torchvision==0.18.0
!pip install -q -U google-generativeai

In [None]:
from rich.console import Console
from rich.style import Style
import pathlib
from rich_theme_manager import Theme, ThemeManager

# Colour themes for rich output. "dark" mixes named colours with an orange accent;
# "light" uses bold hex colours on a white default background. Each key is a rich
# style name (the repr highlighter's "repr.*" styles plus "default").
THEMES = [
    Theme(
        name="dark",
        description="Dark mode theme",
        tags=["dark"],
        styles={
            "repr.own": Style(color="#e87d3e", bold=True),            # Class names
            "repr.tag_name": "dim cyan",                              # Adjust tag names
            "repr.call": "bright_yellow",                             # Function calls and other symbols
            "repr.str": "bright_green",                               # String representation
            "repr.number": "bright_red",                              # Numbers
            "repr.none": "dim white",                                 # None
            "repr.attrib_name": Style(color="#e87d3e", bold=True),    # Attribute names
            "repr.attrib_value": "bright_blue",                       # Attribute values
            "default": "bright_white on black",                       # Default text and background
        },
    ),
    Theme(
        name="light",
        description="Light mode theme",
        styles={
            # Every highlighter style in the light theme is a bold hex colour, so build
            # them from (style name, colour) pairs; insertion order is preserved.
            **{
                style_name: Style(color=hex_color, bold=True)
                for style_name, hex_color in (
                    ("repr.own", "#22863a"),           # Class names
                    ("repr.tag_name", "#00bfff"),      # Adjust tag names
                    ("repr.call", "#ffff00"),          # Function calls and other symbols
                    ("repr.str", "#008080"),           # String representation
                    ("repr.number", "#ff6347"),        # Numbers
                    ("repr.none", "#808080"),          # None
                    ("repr.attrib_name", "#ffff00"),   # Attribute names
                    ("repr.attrib_value", "#008080"),  # Attribute values
                )
            },
            # NOTE(review): pure yellow (#ffff00) on the white background below has very
            # low contrast; consider a darker shade for repr.call / repr.attrib_name.
            "default": Style(color="#000000", bgcolor="#ffffff"),  # Default text and background
        },
    ),
]

# Directory handed to ThemeManager for its theme files ("themes" has no "~", so
# expanduser leaves it relative to the current working directory).
theme_dir = pathlib.Path("themes").expanduser()
theme_dir.mkdir(parents=True, exist_ok=True)

theme_manager = ThemeManager(theme_dir=theme_dir, themes=THEMES)
theme_manager.list_themes()

# Preview the dark theme, but render the rest of the notebook with the light one.
dark = theme_manager.get("dark")
theme_manager.preview_theme(dark)
light = theme_manager.get("light")

console = Console(theme=light)

In [None]:
from datasets import load_dataset
# Load the "train" split of jamescalam/ai-arxiv2 and print the Dataset's repr
# (features and row count) through the themed console.
dataset = load_dataset(path="jamescalam/ai-arxiv2", split="train")
console.print(dataset)

In [None]:
from typing import Dict
import google.generativeai as genai
from semantic_chunkers import StatisticalChunker
from google.colab import userdata
import logging

# Disable all logging output (every level up to and including CRITICAL).
logging.disable(logging.CRITICAL)

# Configure API: the legacy `google.generativeai` SDK keeps this key as module-level state.
GOOGLE_API_KEY = userdata.get('GOOGLE_API_KEY')
genai.configure(api_key=GOOGLE_API_KEY)


class GeminiEncoder(dict):
    """Callable embedding encoder backed by Gemini, passed to StatisticalChunker.

    Subclasses ``dict`` so that ``name`` and ``score_threshold`` are exposed as
    mapping keys. NOTE(review): presumably this is what the chunker's encoder
    validation reads — confirm against semantic_chunkers before changing the base.

    Args:
        model_name: Gemini embedding model id passed to ``embed_content``.
        score_threshold: Stored under the ``score_threshold`` key; not used by this
            class itself (presumably consumed by the chunker).
    """

    def __init__(self, model_name="models/text-embedding-004", score_threshold=0.3):
        super().__init__(name=model_name, score_threshold=score_threshold)
        self.model_name = model_name

    def __call__(self, docs):
        """Embed each string in ``docs`` (one API call per doc); returns a list of vectors."""
        # Resolve the legacy SDK locally instead of via the global name `genai`: a later
        # cell runs `from google import genai`, which rebinds `genai` to the new SDK
        # (no `embed_content`), so re-running the chunking cell would otherwise crash.
        # The import returns the same already-configured module object.
        import google.generativeai as legacy_genai

        return [
            legacy_genai.embed_content(model=self.model_name, content=doc)["embedding"]
            for doc in docs
        ]


encoder = GeminiEncoder()

chunker = StatisticalChunker(
    encoder=encoder,
    min_split_tokens=100,
    max_split_tokens=500
)


In [None]:
# Chunk only the first paper. Index the row first: `dataset["content"]` may
# materialize the entire content column (a Python list in `datasets` 3.x) just to
# read one element.
chunks_0 = chunker(docs=[dataset[0]["content"]])

In [None]:
from rich.text import Text
from rich.panel import Panel

# First chunk of the first (and only) chunked document, re-joined from its splits
# and shown in a bordered panel.
chunk_0_0 = " ".join(chunks_0[0][0].splits)
content = Text(chunk_0_0)
console.print(
    Panel(
        content,
        title="Chunk 0",
        expand=False,
        border_style="bold",
    )
)

In [None]:
from google import genai
from google.genai import types
from google.colab import userdata

# --- PROMPTS ---
# Document half of the contextual-retrieval prompt: situate_context() formats the
# full document into this and sends it as the cached user turn.
DOCUMENT_CONTEXT_PROMPT = """
<document>
{doc_content}
</document>
"""

# Per-chunk half of the prompt, sent as a fresh user turn for every chunk; the model
# is asked to answer with only the short situating context.
CHUNK_CONTEXT_PROMPT = """
Here is the chunk we want to situate within the whole document
<chunk>
{chunk_content}
</chunk>

Please give a short succinct context to situate this chunk within the overall document
for the purposes of improving search retrieval of the chunk.
Answer only with the succinct context and nothing else.
"""

# --- CREATE CLIENT (pass API key here) ---
# Client for the newer `google-genai` SDK (used for context caching). Note that this
# cell's `from google import genai` rebinds the name `genai`, which earlier referred
# to the legacy `google.generativeai` module.
GOOGLE_API_KEY = userdata.get('GOOGLE_API_KEY')
client = genai.Client(api_key=GOOGLE_API_KEY)

# --- FUNCTION ---
import hashlib
import time

# Lifetime requested for each context cache (sent to the API as "<n>s").
_CACHE_TTL_SECONDS = 300
# Recreate a cache this long before its TTL runs out so a request never races expiry.
_CACHE_REFRESH_MARGIN_SECONDS = 30
# (model, sha256 of document) -> (cache name, time.monotonic() when it was requested).
_DOC_CONTEXT_CACHES = {}


def _get_document_cache(model: str, doc: str) -> str:
    """Return the name of a live context cache holding ``doc``, creating one if needed.

    A cache is reused for repeated calls with the same document and model, and is
    recreated once it is within ``_CACHE_REFRESH_MARGIN_SECONDS`` of its TTL.
    """
    key = (model, hashlib.sha256(doc.encode("utf-8", "surrogatepass")).hexdigest())
    # Timestamp taken before the create call, so the age estimate is conservative.
    now = time.monotonic()
    cached = _DOC_CONTEXT_CACHES.get(key)
    if cached is not None:
        cache_name, created_at = cached
        if now - created_at < _CACHE_TTL_SECONDS - _CACHE_REFRESH_MARGIN_SECONDS:
            return cache_name

    cache = client.caches.create(
        model=model,
        config=types.CreateCachedContentConfig(
            display_name="document_context_cache",
            system_instruction=(
                "You are helping generate semantic context for document chunks "
                "to improve retrieval accuracy."
            ),
            contents=[
                {"role": "user", "parts": [{"text": DOCUMENT_CONTEXT_PROMPT.format(doc_content=doc)}]}
            ],
            ttl=f"{_CACHE_TTL_SECONDS}s",
        ),
    )
    _DOC_CONTEXT_CACHES[key] = (cache.name, now)
    return cache.name


def situate_context(doc: str, chunk: str) -> str:
    """Generate a short retrieval-oriented context that situates ``chunk`` in ``doc``.

    The full document lives in a Gemini context cache that is shared by all chunks
    of the same document (previously a new cache was created on every call, so the
    whole document was re-uploaded once per chunk). Only the chunk prompt is sent
    per request.

    Args:
        doc: Full document text.
        chunk: Chunk text to situate within ``doc``.

    Returns:
        The model's response text with surrounding whitespace stripped.
    """
    model = "models/gemini-2.0-flash-001"  # Must include version suffix

    cache_name = _get_document_cache(model, doc)

    response = client.models.generate_content(
        model=model,
        contents=[
            {"role": "user", "parts": [{"text": CHUNK_CONTEXT_PROMPT.format(chunk_content=chunk)}]},
        ],
        config=types.GenerateContentConfig(cached_content=cache_name),
    )

    return response.text.strip()


In [None]:
# Situate the first chunk within the first paper. Index the row first so the whole
# `content` column is not materialized just to read one document.
chunk_context = situate_context(dataset[0]["content"], chunk_0_0)

In [None]:
# Display the context generated for the first chunk through the themed console.
console.print(chunk_context)

In [None]:
from tqdm.auto import tqdm

# Build the contextualized corpus for the first paper: each record holds the chunk
# text followed by its generated context, plus paper-level metadata.
first_paper = dataset[0]  # read the row once instead of re-indexing per field
arxiv_id = first_paper["id"]
refs = list(first_paper["references"].values())
doc_text = first_paper["content"]
title = first_paper["title"]

corpus_json = []
for i, chunk in tqdm(enumerate(chunks_0[0]), total=len(chunks_0[0]), desc="Processing chunks"):
    chunk_text = ' '.join(chunk.splits)
    # situate_context already returns the stripped response text (a str); the
    # previous `.text` access raised AttributeError on the first chunk.
    contextualized_text = situate_context(doc_text, chunk_text)
    corpus_json.append({
        "id": i,
        "text": f"{chunk_text}\n\n{contextualized_text}",
        "metadata": {
            "title": title,
            "arxiv_id": arxiv_id,
            "references": refs
        }
    })

In [None]:
import json
import os

# Create the data directory if it doesn't exist
os.makedirs('data', exist_ok=True)

# Serialize before opening the file: if any record is not JSON-serializable the
# error is raised here, instead of after an existing data/corpus.json has already
# been truncated by open(..., 'w').
corpus_payload = json.dumps(corpus_json)
with open('data/corpus.json', 'w', encoding='utf-8') as f:
    f.write(corpus_payload)