# Late Chunking

This notebook is a work in progress that shows how to perform late chunking: the text is embedded as a whole first, and the resulting token embeddings are then pooled per chunk.

In [None]:
# Install the dependencies listed in the requirements file two directories up.
%pip install -r "../../requirements.txt"

In [None]:
from docling.document_converter import DocumentConverter

# We will be using docling to convert the PDF into markdown.

# The conversion might run for a couple of minutes as the PDF is fairly large.
# `source` is a relative path to the input PDF; `converter` is reused by the next cell.
source = "../../fixtures/Delta Lake Definitive Guide.pdf"
converter = DocumentConverter()

In [None]:
# Run the docling conversion on the PDF path defined in the previous cell.
result = converter.convert(source)

# Export the converted document as markdown; `document` is what the later cells
# split by header and then chunk.
document = result.document.export_to_markdown()

In [None]:
from transformers import AutoModel, AutoTokenizer

# Embedding model used for late chunking; the tokenizer and model are loaded from the same name.
MODEL_NAME = "jinaai/jina-embeddings-v2-base-en"

# Alternative model (disabled). According to the original author's note, jina-embeddings-v3
# needs an embedding task to be selected, which the late_chunking function in this notebook
# has no way to pass. Without a task it would not use a LoRA adapter, i.e. it would produce
# regular embeddings instead of query- or document-specialised ones.
# MODEL_NAME = "jinaai/jina-embeddings-v3"

# NOTE(review): trust_remote_code=True allows model code shipped with the checkpoint to be
# executed locally — only use it with repositories you trust.
jina_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
jina_model = AutoModel.from_pretrained(MODEL_NAME, trust_remote_code=True)

# Full


In [None]:
def late_chunking(document, model, tokenizer):
    """Split ``document`` into sentence-like chunks and embed them with late chunking.

    The whole document is encoded by ``model`` in a single forward pass. Each
    chunk embedding is then the mean of the token embeddings that fall inside
    that chunk, so every chunk vector is computed from tokens that were encoded
    together with the rest of the document.

    A chunk ends at a "." token that is directly followed by a space or a
    newline in ``document``. Any text left after the last such boundary (for
    example a final sentence that ends the document) becomes one more chunk
    instead of being dropped.

    Args:
        document: The text to chunk and embed.
        model: Callable invoked as ``model(**inputs)``; the first element of its
            output is indexed as ``[0, token_start:token_end]`` and mean-pooled
            over the token axis.
        tokenizer: Callable tokenizer that supports ``return_tensors="pt"`` and
            ``return_offsets_mapping=True`` and provides
            ``convert_tokens_to_ids``.

    Returns:
        A tuple ``(chunks, embeddings)`` of two equally long lists: the stripped
        chunk texts and, at the same index, one NumPy vector per chunk. Both
        lists are empty when the document contains no chunkable text.
    """
    # Imported locally so the function does not rely on another cell having
    # imported torch.
    import torch

    # Tokenize once; the offsets are needed for the boundaries, the remaining
    # entries are fed to the model further down.
    encoded = tokenizer(document, return_tensors="pt", return_offsets_mapping=True)
    # Plain Python lists are much cheaper to iterate than tensor elements.
    token_offsets = encoded["offset_mapping"][0].tolist()
    token_ids = encoded["input_ids"][0].tolist()

    punctuation_mark_id = tokenizer.convert_tokens_to_ids(".")

    # Each entry is (chunk_text, first_token, one_past_last_token), which keeps
    # texts and token spans aligned by construction.
    spans = []
    span_start_char, span_start_token = 0, 0

    # The last token is never treated as a boundary (hence ``[:-1]``).
    for i, (token_id, (_, end)) in enumerate(zip(token_ids[:-1], token_offsets)):
        if token_id == punctuation_mark_id and document[end : end + 1] in (" ", "\n"):
            spans.append(
                (document[span_start_char:end].strip(), span_start_token, i + 1)
            )
            # Skip the whitespace character that follows the period.
            span_start_char, span_start_token = end + 1, i + 1

    # Keep whatever follows the last boundary as a final chunk.
    trailing_text = document[span_start_char:].strip()
    if trailing_text:
        # Tokens with an empty character offset (such as special tokens the
        # tokenizer appends) are left out of the trailing span.
        last_text_token = next(
            (
                idx
                for idx in range(len(token_offsets) - 1, span_start_token - 1, -1)
                if token_offsets[idx][1] > token_offsets[idx][0]
            ),
            None,
        )
        if last_text_token is not None:
            spans.append((trailing_text, span_start_token, last_text_token + 1))

    if not spans:
        # Nothing to embed, so skip the forward pass entirely.
        return [], []

    # The model does not take the offset mapping as an input.
    model_inputs = {
        key: value for key, value in encoded.items() if key != "offset_mapping"
    }

    # Inference only: disabling autograd avoids keeping activations for a
    # backward pass that never happens.
    with torch.no_grad():
        token_embeddings = model(**model_inputs)[0]

    # Mean-pool the token embeddings of every span.
    chunks, embeddings = [], []
    for text, start_token, end_token in spans:
        chunks.append(text)
        pooled = token_embeddings[0, start_token:end_token].mean(dim=0)
        embeddings.append(pooled.detach().cpu().numpy())

    return chunks, embeddings

In [None]:
from langchain_text_splitters import MarkdownHeaderTextSplitter

# Late chunking the full document is too memory intensive, so the markdown is first
# split into sections by header. Each section is then chunked further using late chunking.

headers_to_split_on = [
    ("#", "Header 1"),
    ("##", "Header 2"),
    # ("###", "Header 3"),
]

markdown_splitter = MarkdownHeaderTextSplitter(
    headers_to_split_on,
    # return_each_line=True,
)
md_header_splits = markdown_splitter.split_text(document)

print(f"Number of chunks: {len(md_header_splits)}")

# Sections with more tokens than this are reported below.
# NOTE(review): 32000 is the threshold the notebook has always used — confirm it
# against the context length of the embedding model.
MAX_SECTION_TOKENS = 32000

# The loop variable is deliberately NOT called `document`: rebinding that name would
# overwrite the markdown string and break this cell when it is run a second time.
for header_split in md_header_splits:
    tokens = (
        jina_tokenizer(header_split.page_content, return_tensors="pt")
        .get("input_ids")
        .shape[1]
    )
    if tokens > MAX_SECTION_TOKENS:
        print(f"too many tokens: {tokens}")

In [None]:
# Collect the chunks and their embeddings across all header sections.
# Both lists are reset here, so re-running the cell does not duplicate entries.
late_chunks = []
late_embeddings = []

# The loop variable is deliberately NOT called `document`: rebinding that name would
# overwrite the markdown string that the header-splitting cell reads.
for header_split in md_header_splits:
    section_chunks, section_embeddings = late_chunking(
        header_split.page_content, jina_model, jina_tokenizer
    )
    late_chunks.extend(section_chunks)
    late_embeddings.extend(section_embeddings)

# Every chunk text must have exactly one embedding at the same index.
assert len(late_chunks) == len(late_embeddings), "chunks and embeddings are misaligned"