<a href="https://colab.research.google.com/github/zamanmiraz/DSandML-Notebooks/blob/main/RAG/03_semantic_chunking.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!git clone https://github.com/guyernest/advanced-rag.git
%cd advanced-rag
!pip install --upgrade -r requirements.txt

In [None]:
!pip install torchvision==0.18.0
!pip install -q -U google-generativeai

In [None]:
from rich.console import Console
from rich.style import Style
import pathlib
from rich_theme_manager import Theme, ThemeManager

# Two rich themes registered with the ThemeManager below; the notebook console uses "light".
THEMES = [
    Theme(
        name="dark",
        description="Dark mode theme",
        tags=["dark"],
        styles={
            "repr.own": Style(color="#e87d3e", bold=True),      # Class names
            "repr.tag_name": "dim cyan",                        # Adjust tag names
            "repr.call": "bright_yellow",                       # Function calls and other symbols
            "repr.str": "bright_green",                         # String representation
            "repr.number": "bright_red",                        # Numbers
            "repr.none": "dim white",                           # None
            "repr.attrib_name": Style(color="#e87d3e", bold=True),    # Attribute names
            "repr.attrib_value": "bright_blue",                 # Attribute values
            "default": "bright_white on black"                  # Default text and background
        },
    ),
    Theme(
        name="light",
        description="Light mode theme",
        tags=["light"],
        styles={
            "repr.own": Style(color="#22863a", bold=True),          # Class names
            "repr.tag_name": Style(color="#00bfff", bold=True),     # Adjust tag names
            # Dark gold instead of pure yellow (#ffff00), which is unreadable on white.
            "repr.call": Style(color="#b58900", bold=True),         # Function calls and other symbols
            "repr.str": Style(color="#008080", bold=True),          # String representation
            "repr.number": Style(color="#ff6347", bold=True),       # Numbers
            "repr.none": Style(color="#808080", bold=True),         # None
            "repr.attrib_name": Style(color="#b58900", bold=True),  # Attribute names (same contrast fix)
            "repr.attrib_value": Style(color="#008080", bold=True), # Attribute values
            "default": Style(color="#000000", bgcolor="#ffffff"),   # Default text and background
        },
    ),
]

# Local "themes" folder handed to ThemeManager (created here if it is missing).
# theme_dir is already expanded, so a single expanduser() is enough.
theme_dir = pathlib.Path("themes").expanduser()
theme_dir.mkdir(exist_ok=True, parents=True)

# Register the THEMES defined above and list what the manager knows about.
theme_manager = ThemeManager(themes=THEMES, theme_dir=theme_dir)
theme_manager.list_themes()

# Show a sample rendering of the dark palette.
dark = theme_manager.get("dark")
theme_manager.preview_theme(dark)

In [None]:
from rich.console import Console

# Fetch both registered palettes; the notebook-wide console renders with "light".
light = theme_manager.get("light")
dark = theme_manager.get("dark")
console = Console(theme=light)

In [None]:
from datasets import load_dataset
# Load the train split of the "jamescalam/ai-arxiv2" dataset and show its summary.
dataset = load_dataset(path="jamescalam/ai-arxiv2", split="train")
console.print(dataset)

In [None]:
dataset.num_rows  # number of rows in the train split

In [None]:
# Peek at the opening 500 characters of one row's text (row 3) before chunking.
content = dataset[3]["content"]
opening = content[:500]
console.print(opening)

In [None]:
import google.generativeai as genai
from semantic_chunkers import StatisticalChunker
from google.colab import userdata
import logging

# Suppress every log record (CRITICAL and everything below it) so the cell output stays clean.
# NOTE(review): this is process-wide and also hides library warnings/errors;
# consider raising only the noisy loggers' levels instead.
logging.disable(level=logging.CRITICAL)

# Read the Gemini key from Colab's secret store and hand it to the SDK.
GOOGLE_API_KEY = userdata.get('GOOGLE_API_KEY')
genai.configure(api_key=GOOGLE_API_KEY)


class GeminiEncoder(dict):
    """Callable Gemini embedding encoder that is also a ``dict``.

    The instance holds the keys ``name`` and ``score_threshold`` and is passed as
    ``encoder=`` to ``StatisticalChunker`` in this notebook. Calling it embeds a batch
    of texts through ``genai.embed_content``.
    NOTE(review): the dict shape appears to be a workaround for the chunker's encoder
    validation — confirm against the semantic_chunkers version in use.

    Args:
        model_name: Gemini embedding model id passed to ``genai.embed_content``.
        score_threshold: Stored under the ``score_threshold`` key and attribute;
            not read by this class itself.
        batch_size: Maximum number of texts sent per ``embed_content`` request.
        task_type: Optional Gemini task type (e.g. ``"semantic_similarity"``);
            ``None`` leaves the API default in place.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """

    def __init__(self, model_name="models/text-embedding-004", score_threshold=0.3,
                 batch_size=100, task_type=None):
        # Keep the dict contents exactly as before (only these two keys).
        super().__init__(name=model_name, score_threshold=score_threshold)
        if batch_size < 1:
            raise ValueError("batch_size must be >= 1")
        self.model_name = model_name
        # Also expose the dict keys as attributes, so attribute-style access
        # (`encoder.name`, `encoder.score_threshold`) works too.
        self.name = model_name
        self.score_threshold = score_threshold
        self.batch_size = batch_size
        self.task_type = task_type

    def __call__(self, docs):
        """Embed ``docs`` and return one embedding (list of floats) per text, in order.

        A single ``str`` is treated as a one-element batch rather than being
        iterated character by character. Empty input returns ``[]`` without
        calling the API.
        """
        texts = [docs] if isinstance(docs, str) else list(docs)
        extra = {} if self.task_type is None else {"task_type": self.task_type}

        embeddings = []
        # One request per batch instead of one per text: far fewer round trips
        # (and rate-limit hits) for documents with many splits.
        for start in range(0, len(texts), self.batch_size):
            batch = texts[start:start + self.batch_size]
            result = genai.embed_content(model=self.model_name, content=batch, **extra)
            embeddings.extend(result["embedding"])
        return embeddings

# Create encoder instance
encoder = GeminiEncoder()

# ✅ Initialize chunker: splits are bounded to 100–500 tokens; plotting and
# statistics output are left on for inspection (see the semantic_chunkers docs).
chunker = StatisticalChunker(
    encoder=encoder,
    min_split_tokens=100,
    max_split_tokens=500,
    plot_chunks=True,
    enable_statistics=True,
)

# ✅ Example usage. `dataset[0]["content"]` reads a single row, whereas
# `dataset["content"][0]` can materialise the entire "content" column first.
chunks_0 = chunker(docs=[dataset[0]["content"]])


In [None]:
# Index [0] selects the result for the single document passed to the chunker; show five chunks.
paper0_chunks = chunks_0[0]
console.print(paper0_chunks[:5])

In [None]:
# Read just row 1 instead of materialising the whole "content" column.
chunks_1 = chunker(docs=[dataset[1]["content"]])

In [None]:
# Index [0] selects the result for the single document passed to the chunker; show three chunks.
paper1_chunks = chunks_1[0]
console.print(paper1_chunks[:3])

In [None]:
from rich.text import Text
from rich.panel import Panel

def build_chunk(title: str, content: str):
    """Return ``content`` prefixed by a markdown H1 line holding ``title``."""
    return "# {}\n{}".format(title, content)

# Render the first three chunks of paper 0, each prefixed with the paper title.
title = dataset[0]["title"]
for i, split in enumerate(chunks_0[0][:3]):
    # Dedicated name so the `content` global (a paper's raw text) is not
    # overwritten with a rich Text object.
    chunk_text = Text(build_chunk(title=title, content=split.content))
    console.print(Panel(chunk_text, title=f"Chunk {i + 1}", expand=False, border_style="bold"))

In [None]:
# Build metadata for every chunk of paper 0, embedding neighbouring chunk text inline.
paper = dataset[0]
arxiv_id = paper["id"]
# Read the title here so this cell does not depend on `title` left over from another cell.
title = paper["title"]
refs = list(paper["references"].values())
splits = chunks_0[0]

metadata = []
for i, chunk in enumerate(splits):
    prechunk = splits[i - 1].content if i > 0 else ""
    postchunk = splits[i + 1].content if i + 1 < len(splits) else ""
    metadata.append({
        "title": title,
        "content": chunk.content,
        "prechunk": prechunk,
        "postchunk": postchunk,
        "arxiv_id": arxiv_id,
        "references": refs
    })

In [None]:
def build_metadata(doc: dict, doc_splits):
    """Build one metadata record per chunk of a single document.

    Args:
        doc: A dataset row. Its "id", "title" and "references" entries are read;
            the values of "references" are collected into one list that every
            record shares.
        doc_splits: Sequence of chunk objects exposing ``.content``, in order.

    Returns:
        A list with one dict per split, keyed id, title, content, prechunk_id,
        postchunk_id, arxiv_id, references. Ids have the form "<id>#<index>";
        a missing neighbour is represented by "".
    """
    arxiv_id = doc["id"]
    title = doc["title"]
    refs = list(doc["references"].values())

    def chunk_id(index):
        return f"{arxiv_id}#{index}"

    def to_record(index, split):
        # Neighbour ids link each chunk to the ones before and after it.
        prev_id = chunk_id(index - 1) if index > 0 else ""
        next_id = chunk_id(index + 1) if index + 1 < len(doc_splits) else ""
        return {
            "id": chunk_id(index),
            "title": title,
            "content": split.content,
            "prechunk_id": prev_id,
            "postchunk_id": next_id,
            "arxiv_id": arxiv_id,
            "references": refs,
        }

    return [to_record(index, split) for index, split in enumerate(doc_splits)]

In [None]:
# Build records for *all* chunks so the neighbour ids are correct, then keep three for display.
# Slicing the splits first would give chunk #2 an empty postchunk_id even when chunk #3 exists.
metadata = build_metadata(
    doc=dataset[0],
    doc_splits=chunks_0[0]
)[:3]

In [None]:
# Render the chunk metadata records with the themed `console`.
console.print(metadata)