In [None]:
import re
from typing import List, Optional
from transformers import AutoTokenizer


class MarkdownChunker:
    """
    Chunk Markdown text into pieces that fit within a token budget using a tokenizer.
    Strategy:
      - Split by Markdown headers first to preserve structure (header-looking
        lines inside fenced code blocks are NOT treated as headers).
      - For oversized sections, split hierarchically: paragraphs -> lines -> sentences -> words.
        Only the units that are still too large are pushed down to the next level.
      - Greedily pack adjacent units at each level while respecting max_tokens.
      - As a last resort (e.g., a single long word), fall back to character-level greedy splitting.

    The tokenizer only needs an ``encode(text, add_special_tokens=False)`` method
    returning a sized sequence; that is the only call made on it after construction.
    """

    def __init__(self, tokenizer: Optional["AutoTokenizer"] = None, tokenizer_name: str = "gpt2"):
        """
        Args:
            tokenizer: A ready-made tokenizer. When omitted, one is loaded with
                ``AutoTokenizer.from_pretrained(tokenizer_name)``.
            tokenizer_name: Model name used only when ``tokenizer`` is None.
        """
        # Explicit None check: do not rely on the truthiness of a tokenizer object.
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.tokenizer = tokenizer

        # Reuse EOS as PAD when the tokenizer defines no pad token. getattr keeps
        # this safe for minimal tokenizer objects that lack these attributes.
        if getattr(tokenizer, "pad_token", None) is None and getattr(tokenizer, "eos_token", None):
            tokenizer.pad_token = tokenizer.eos_token

        # Precompile patterns
        self.header_re = re.compile(r"^(#{1,6}\s+.*)$", re.MULTILINE)

    def count_tokens(self, text: str) -> int:
        """Return the number of tokens in ``text`` (0 for empty text), without special tokens."""
        if not text:
            return 0
        return len(self.tokenizer.encode(text, add_special_tokens=False))

    def chunk(self, markdown_text: str, max_tokens: int) -> List[str]:
        """
        Split ``markdown_text`` into chunks of at most ``max_tokens`` tokens each.

        Returns an empty list for empty or whitespace-only input.

        Raises:
            ValueError: if ``max_tokens`` is smaller than 1.
        """
        if max_tokens < 1:
            raise ValueError("max_tokens must be a positive integer")
        if not markdown_text or markdown_text.isspace():
            return []

        # Split oversized sections first to ensure everything fits
        pieces: List[str] = []
        for section in self._split_by_headers(markdown_text):
            pieces.extend(self._split_section_to_fit(section, max_tokens))

        # Greedy merge adjacent sections while respecting max_tokens
        return self._merge_sections(pieces, max_tokens)

    # --- Internal helpers ---

    @staticmethod
    def _fence_marker(stripped_line: str) -> Optional[str]:
        """
        Return the run of fence characters if ``stripped_line`` opens a fenced
        code block (three or more backticks or tildes), else None.
        """
        for ch in ("`", "~"):
            if stripped_line.startswith(ch * 3):
                run = stripped_line[: len(stripped_line) - len(stripped_line.lstrip(ch))]
                # A backtick fence may not contain backticks after the opening run;
                # something like ```code``` on one line is inline code, not a fence.
                if ch == "`" and "`" in stripped_line[len(run):]:
                    return None
                return run
        return None

    @staticmethod
    def _join_section(header: Optional[str], body_lines: List[str]) -> Optional[str]:
        """Combine a header and its body lines into one section, or None if both are blank."""
        content = "\n".join(body_lines).strip()
        if header and content:
            return f"{header}\n{content}"
        return header or content or None

    def _split_by_headers(self, text: str) -> List[str]:
        """
        Split ``text`` into sections, each starting at a Markdown header line.

        Text before the first header becomes its own section. The header pattern
        is applied one line at a time, so a header can never span several lines,
        and lines inside fenced code blocks are never treated as headers.
        """
        sections: List[str] = []
        header: Optional[str] = None
        body: List[str] = []
        fence: Optional[str] = None  # fence run that opened the current code block, if any

        for line in text.split("\n"):
            stripped = line.lstrip()
            if fence is None:
                opening = self._fence_marker(stripped)
                if opening is not None:
                    fence = opening
                elif self.header_re.match(line):
                    finished = self._join_section(header, body)
                    if finished:
                        sections.append(finished)
                    header, body = line.strip(), []
                    continue
            elif stripped.startswith(fence) and not stripped.lstrip(fence[0]).strip():
                # Closing fence: at least as long as the opener and nothing else on the line.
                fence = None
            body.append(line)

        finished = self._join_section(header, body)
        if finished:
            sections.append(finished)
        return sections

    def _greedy_pack(self, units: List[str], joiner: str, max_tokens: int) -> List[str]:
        """
        Greedily join adjacent units with ``joiner`` while the result stays within
        ``max_tokens``. A unit that is too large on its own is emitted alone.
        """
        chunks: List[str] = []
        current = ""

        for unit in units:
            candidate = f"{current}{joiner}{unit}" if current else unit
            if self.count_tokens(candidate) <= max_tokens:
                current = candidate
            else:
                if current:
                    chunks.append(current.strip())
                current = unit

        if current:
            chunks.append(current.strip())
        return chunks

    def _merge_sections(self, sections: List[str], max_tokens: int) -> List[str]:
        """Greedily merge adjacent sections, separated by a blank line, within the budget."""
        return self._greedy_pack(sections, "\n\n", max_tokens)

    def _split_section_to_fit(self, section: str, max_tokens: int) -> List[str]:
        """
        Return ``section`` unchanged if it fits; otherwise split it, keeping the
        header attached to the first content chunk whenever that still fits.
        """
        if self.count_tokens(section) <= max_tokens:
            return [section]

        header, content = self._extract_header(section)

        if header:
            # Split content first
            content_chunks = self._split_text_hierarchically(content, max_tokens)

            # Try to attach header to the first content chunk
            if content_chunks:
                first_with_header = f"{header}\n{content_chunks[0]}"
                if self.count_tokens(first_with_header) <= max_tokens:
                    return [first_with_header] + content_chunks[1:]

            # If header alone fits, keep it as its own chunk
            if self.count_tokens(header) <= max_tokens:
                return [header] + content_chunks

            # If header itself is too large, split the header too
            header_chunks = self._split_text_hierarchically(header, max_tokens)
            return header_chunks + content_chunks

        # No header: split the whole section
        return self._split_text_hierarchically(section, max_tokens)

    def _extract_header(self, section: str) -> tuple:
        """
        Return ``(header_line, rest)`` when the first line of ``section`` is a
        Markdown header, otherwise ``(None, section)``.
        """
        lines = section.splitlines()
        if lines and self.header_re.match(lines[0]):
            rest = "\n".join(lines[1:]).strip()
            return lines[0], rest
        return None, section

    def _split_text_hierarchically(self, text: str, max_tokens: int) -> List[str]:
        """
        Attempts progressively finer splits and greedy packing at each level.
        Always returns a list (never None). Every returned chunk fits the budget,
        except when a single character alone exceeds it, in which case that
        character is emitted on its own so that no content is lost.
        """
        if not text:
            return []

        # Stop if already fits
        if self.count_tokens(text) <= max_tokens:
            return [text]

        return self._split_from_level(text, max_tokens, 0)

    def _split_from_level(self, text: str, max_tokens: int, level: int) -> List[str]:
        """
        Split an oversized ``text`` starting at the given hierarchy level.

        Units that still exceed the budget are recursively split at the next
        level; the resulting pieces are packed with this level's joiner.
        """
        # Levels: paragraphs -> lines -> sentences -> words
        levels = (
            (self._split_paragraphs, "\n\n"),
            (self._split_lines, "\n"),
            (self._split_sentences, " "),
            (self._split_words, " "),
        )
        if level >= len(levels):
            return self._split_by_characters(text, max_tokens)

        splitter, joiner = levels[level]
        parts = [p for p in splitter(text) if p.strip()]
        if len(parts) <= 1:
            # This level cannot break the text up; try a finer one.
            return self._split_from_level(text, max_tokens, level + 1)

        pieces: List[str] = []
        for part in parts:
            if self.count_tokens(part) <= max_tokens:
                pieces.append(part)
            else:
                pieces.extend(self._split_from_level(part, max_tokens, level + 1))
        return self._greedy_pack(pieces, joiner, max_tokens)

    def _split_by_characters(self, text: str, max_tokens: int) -> List[str]:
        """
        Last-resort splitter: grow each piece one character at a time and cut
        just before the token budget would be exceeded. A piece always holds at
        least one character, so the loop makes progress for any budget.
        """
        pieces: List[str] = []
        current = ""
        for ch in text:
            candidate = current + ch
            if current and self.count_tokens(candidate) > max_tokens:
                pieces.append(current)
                current = ch
            else:
                current = candidate
        if current:
            pieces.append(current)

        stripped = (piece.strip() for piece in pieces)
        return [piece for piece in stripped if piece]

    def _pack_units(self, units: List[str], joiner: str, max_tokens: int) -> Optional[List[str]]:
        """
        Greedily joins adjacent units with the given joiner while staying under max_tokens.
        Returns None if any individual unit already exceeds max_tokens (needs deeper split).

        Kept for backward compatibility; the hierarchical splitter no longer calls it.
        """
        if any(self.count_tokens(unit) > max_tokens for unit in units):
            return None  # needs deeper split
        return self._greedy_pack(units, joiner, max_tokens)

    # Splitters

    def _split_paragraphs(self, text: str) -> List[str]:
        """Split on blank lines; returns stripped, non-empty paragraphs."""
        return [p.strip() for p in re.split(r"\n{2,}", text) if p.strip() != ""]

    def _split_lines(self, text: str) -> List[str]:
        """Split on newlines, dropping blank lines but keeping each line's indentation."""
        return [ln for ln in text.split("\n") if ln.strip() != ""]

    def _split_sentences(self, text: str) -> List[str]:
        """Naive sentence splitter that keeps the terminating punctuation."""
        parts = re.findall(r'.*?(?:[\.!\?](?!\w)|$)', text, flags=re.S)
        return [p.strip() for p in parts if p and p.strip()]

    def _split_words(self, text: str) -> List[str]:
        """Split on whitespace; returns ``[text]`` when there are no words."""
        words = text.split()
        return words if words else [text]



In [15]:
# Demo Markdown document used to exercise the chunker. It contains headers of
# all six levels, a bullet list and a numbered list, a fenced code block, a
# blockquote, and several long paragraphs.
sample = """# Introduction

This is an introductory paragraph that provides a general overview of the document. It should be grouped with the header above when possible, unless it's too large to fit in a chunk.

## Section One

Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.

### Subsection A

Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.

#### Small Header

Short line here.

##### Even Smaller Header

Another short line.

###### Tiny Header

Final small header.

## Section Two

This section has a long list:
- Item one
- Item two
- Item three
- Item four
- Item five
- Item six
- Item seven
- Item eight
- Item nine
- Item ten

And now we have a numbered list:
1. First item
2. Second item
3. Third item
4. Fourth item
5. Fifth item

## Code Example

Here’s some code:

```python
def hello_world():
    print("Hello, world!")
    for i in range(10000):
        print(f"Number: {i}")
```

## Blockquote

> This is a blockquote.  
> It spans multiple lines.  
> Let's make sure it chunks correctly.

## Very Long Line

This is a very long sentence that just keeps going on and on and on without any natural break points, which might force the chunker to split at the token level if nothing else works. We want to ensure that even such awkward content gets handled gracefully by the chunking logic.

## Final Section

Just a final wrap-up section to close things off. This paragraph may be short or combined with others depending on the max token limit used during chunking.

"""


# Build the chunker; its constructor loads the named tokenizer via
# AutoTokenizer.from_pretrained when no tokenizer object is passed in.
chunker = MarkdownChunker(tokenizer_name="mixedbread-ai/mxbai-embed-xsmall-v1")

# Token budget per chunk -- try different values: 50, 75, 100, etc.
chunks = chunker.chunk(sample, max_tokens=100)

# Print a banner line, the chunk text, and a blank separator line for each chunk.
for i, chunk in enumerate(chunks):
    print(f"--- Chunk {i+1} ---", chunk, "", sep="\n")


--- Chunk 1 ---
# Introduction
This is an introductory paragraph that provides a general overview of the document. It should be grouped with the header above when possible, unless it's too large to fit in a chunk.

--- Chunk 2 ---
## Section One
Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.

--- Chunk 3 ---
### Subsection A
Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.

#### Small Header
Short line here.

--- Chunk 4 ---
##### Even Smaller Header
Another short line.

###### Tiny Header
Final small header.

## Section Two
This section has a long list:
- Item one
- Item two
- Item three
- Item four
- Item five
- Item six
- Item seven
- Item