# Updates on the Chunking Algorithm

<a target="_blank" href="https://colab.research.google.com/github/sweepai/sweep/blob/main/notebooks/chunking.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

This notebook accompanies our blog post on improvements to our chunking algorithm.

In [65]:
!pip install tree_sitter_languages

import re
from tree_sitter_languages import get_language, get_parser

# Python grammar and parser from the prebuilt tree_sitter_languages bundle.
# (`language` is not referenced in the cells shown here; only `parser` is used.)
language = get_language("python")
parser = get_parser("python")

## Meet the Span

In [140]:
from __future__ import annotations
from dataclasses import dataclass


@dataclass
class Span:
    """A half-open slice ``[start, end)`` of some sequence.

    The unit is up to the caller: this notebook uses the same class for
    tree-sitter byte offsets, string indices and line numbers.
    """

    start: int = 0
    # None means "empty span at start"; normalized in __post_init__.
    end: int | None = None

    def __post_init__(self):
        # Lets Span(5) mean the empty span Span(5, 5). With the old default of
        # 0 this branch was unreachable and Span(5) had a negative length.
        if self.end is None:
            self.end = self.start

    def extract(self, s: str) -> str:
        """Return ``s[start:end]``.

        The offsets index ``s`` directly: character indices for a ``str``,
        byte offsets for ``bytes``.
        """
        return s[self.start : self.end]

    def extract_lines(self, s: str) -> str:
        """Return lines ``start`` up to (not including) ``end`` of ``s``, newline-joined."""
        return "\n".join(s.splitlines()[self.start : self.end])

    def __add__(self, other: Span | int) -> Span:
        """Shift by an int, or concatenate with another Span.

        ``Span(1, 2) + 3 == Span(4, 5)`` and
        ``Span(1, 2) + Span(2, 4) == Span(1, 4)``. Concatenation is unchecked:
        ``Span(a, b) + Span(c, d) == Span(a, d)`` whether or not ``b == c``.

        Raises:
            NotImplementedError: if ``other`` is neither an int nor a Span.
        """
        if isinstance(other, int):
            return Span(self.start + other, self.end + other)
        elif isinstance(other, Span):
            return Span(self.start, other.end)
        else:
            raise NotImplementedError()

    def __len__(self) -> int:
        """Return ``end - start``; ``len()`` raises ValueError if that is negative."""
        return self.end - self.start

The example code we'll use in this guide comes from https://github.com/sweepai/sweep/blob/b267b613d4c706eaf959fe6789f11e9a856521d1/sweepai/handlers/on_check_suite.py, Sweep's old handler for parsing GitHub Actions run logs.

In [158]:
import requests

example_file = "https://raw.githubusercontent.com/sweepai/sweep/b267b613d4c706eaf959fe6789f11e9a856521d1/sweepai/handlers/on_check_suite.py"
# The URL pins a commit so the output below is reproducible. The timeout keeps
# the cell from hanging indefinitely, and raise_for_status() stops us from
# silently chunking an HTTP error body instead of the source file.
response = requests.get(example_file, timeout=30)
response.raise_for_status()
python_code = response.text

Let's first visualize the syntax tree.

In [None]:
# Parse the UTF-8 encoding of the file: the start_byte/end_byte values
# tree-sitter reports below are offsets into these bytes, not str indices.
tree = parser.parse(python_code.encode("utf-8"))


def pretty_node(node):
    """Render a tree-sitter node as ``type:start_byte-end_byte``."""
    return "{}:{}-{}".format(node.type, node.start_byte, node.end_byte)


def print_tree(node, indent="", min_chars=100):
    """Print the subtree under ``node``, one indented line per node.

    A node whose text has fewer than ``min_chars`` non-whitespace characters
    is skipped along with all its descendants, so only large structural
    nodes are shown.

    Args:
        node: a tree-sitter node (reads ``.text`` and ``.children``, plus
            whatever ``pretty_node`` reads).
        indent: prefix for this node's line; grows by two spaces per level.
        min_chars: the size threshold described above (default 100).
    """
    # Raw string: "\s" in a plain literal is an invalid escape sequence
    # (a SyntaxWarning on Python 3.12+).
    if len(re.sub(r"\s", "", node.text.decode("utf-8"))) < min_chars:
        return
    print(indent + pretty_node(node))
    for child in node.children:
        print_tree(child, indent=indent + "  ", min_chars=min_chars)


# Print each top-level statement's subtree; print_tree filters out small nodes.
for child in tree.root_node.children:
    print_tree(child)

We can see that consecutive nodes don't actually line up:

```python
expression_statement:432-696
  assignment:432-696
    string:446-696
function_definition:698-1502
  block:777-1502
    expression_statement:777-953
      assignment:777-953
        dictionary:787-953
```

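Notice that the `expression_statement` ends at byte 696 and the `function_definition` starts at byte 698. Because `end_byte` is exclusive, bytes 696 and 697 belong to neither node, so any chunking that copies only each node's own bytes silently drops them. `connect_chunks` below fixes this by extending each chunk to the start of the next one.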

In [67]:
def connect_chunks(chunks: list[Span]):
    """Close the gaps between consecutive spans, in place.

    Every span except the last has its ``end`` moved to the following span's
    ``start``, so no bytes between tree-sitter nodes are left uncovered.
    The same (mutated) sequence is returned for convenience.
    """
    for i in range(1, len(chunks)):
        chunks[i - 1].end = chunks[i].start
    return chunks

## Coalescing

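Here is the algorithm presented in the last blog post. Unfortunately, it has some bugs. Most notably, it copies each child as `text[start_byte:end_byte]`, so the bytes between sibling nodes are lost.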

In [None]:
from tree_sitter import Node
from dataclasses import field


def chunk_node(node: Node, text: str, MAX_CHARS: int = 600) -> list[str]:
    """Original chunking algorithm from the previous post, kept as-is for comparison.

    Greedily packs the text of ``node``'s children into chunks of at most
    about MAX_CHARS, recursing into any child that is larger on its own.

    Known issues (addressed by the later versions in this notebook):
      - each child is sliced as ``text[start_byte:end_byte]``, so anything
        between sibling nodes (typically whitespace) is dropped;
      - start_byte/end_byte are offsets into the UTF-8 bytes that were parsed,
        but ``text`` is a str, so slices drift on non-ASCII input;
      - an empty string is appended whenever a flush happens with nothing
        buffered (e.g. when the first child is oversized).
    """
    chunks = []
    current_chunk = ""
    for child in node.children:
        if child.end_byte - child.start_byte > MAX_CHARS:
            # Child too large by itself: flush the buffer, then split the child.
            chunks.append(current_chunk)
            current_chunk = ""
            chunks.extend(chunk_node(child, text, MAX_CHARS))
        elif child.end_byte - child.start_byte + len(current_chunk) > MAX_CHARS:
            # Adding this child would overflow: flush and start over with it.
            chunks.append(current_chunk)
            current_chunk = text[child.start_byte : child.end_byte]
        else:
            # Fits: append the child's text to the buffer.
            current_chunk += text[child.start_byte : child.end_byte]
    chunks.append(current_chunk)

    return chunks


# Print every chunk of the example file, separated by a rule.
for chunk in chunk_node(tree.root_node, python_code):
    print(chunk + "\n\n====================\n\n")

Here it is with the fix: each child is sliced up to the `start_byte` of the next node instead of the `end_byte` of the current node.

I added a fake node at the end whose start and end bytes both equal the end byte of the parent node. That way the last child also has a "next node", and the loop needs no special case for it. `MockNode` exists because tree_sitter's `Node` class has no public constructor.

In [None]:
@dataclass
class MockNode:
    """Minimal stand-in for a tree-sitter Node.

    tree_sitter's Node cannot be constructed directly, so this carries only
    the attributes chunk_node reads: the byte offsets and the children.
    """

    start_byte: int = field(default=0)
    end_byte: int = field(default=0)
    # A fresh empty list per instance (never shared between nodes).
    children: list[MockNode] = field(default_factory=lambda: [])


def chunk_node(node: Node, text: str, MAX_CHARS: int = 600) -> list[str]:
    """Split ``node``'s source text into chunks of at most about MAX_CHARS.

    Each child is sliced up to the start of its next sibling, so the text
    between siblings is kept. A sentinel MockNode at ``node.end_byte`` gives
    the last child a "next sibling". Children larger than MAX_CHARS are split
    recursively.

    NOTE: the offsets are tree-sitter byte offsets while ``text`` is a str;
    they only agree for ASCII input.
    """
    chunks = []
    current_chunk = ""
    node_children = node.children + [MockNode(node.end_byte, node.end_byte)]

    for child, next_child in zip(node_children[:-1], node_children[1:]):
        if child.end_byte - child.start_byte > MAX_CHARS:
            chunks.append(current_chunk)
            chunks.extend(chunk_node(child, text, MAX_CHARS))
            # The recursive call only covers child.start_byte..child.end_byte,
            # so carry the gap up to the next sibling instead of dropping it.
            current_chunk = text[child.end_byte : next_child.start_byte]
        elif child.end_byte - child.start_byte + len(current_chunk) > MAX_CHARS:
            # Adding this child would overflow: flush and start over with it.
            chunks.append(current_chunk)
            current_chunk = text[child.start_byte : next_child.start_byte]
        else:
            current_chunk += text[child.start_byte : next_child.start_byte]
    chunks.append(current_chunk)

    return chunks


# Print every chunk produced by the fixed version, separated by a rule.
for chunk in chunk_node(tree.root_node, python_code):
    print(chunk + "\n\n====================\n\n")

First, using `Span`s we can clean up the code a bit, for example by removing `MockNode` altogether.

In [None]:
def chunk_node(
    node: Node,
    MAX_CHARS: int = 600,
) -> list[Span]:
    """Split ``node`` into byte-offset Spans of at most about MAX_CHARS.

    Children are packed greedily into a pending span. A child that alone
    exceeds MAX_CHARS flushes the pending span and is split recursively.
    Consecutive returned spans may leave gaps between them (e.g. between two
    siblings that land in different spans); connect_chunks closes those.
    """
    result: list[Span] = []
    pending = Span(node.start_byte, node.start_byte)
    for kid in node.children:
        size = kid.end_byte - kid.start_byte
        if size > MAX_CHARS:
            result.append(pending)
            pending = Span(kid.end_byte, kid.end_byte)
            result.extend(chunk_node(kid, MAX_CHARS))
            continue
        if size + len(pending) > MAX_CHARS:
            result.append(pending)
            pending = Span(kid.start_byte, kid.end_byte)
            continue
        pending = pending + Span(kid.start_byte, kid.end_byte)
    result.append(pending)
    return result


# Print the raw byte-offset Spans (their dataclass repr), not their text.
for chunk in chunk_node(tree.root_node):
    print(chunk)

## Skipping Whitespace when Measuring Length

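Counting only non-whitespace characters means heavily indented code isn't penalized: a chunk holds about the same amount of code regardless of how deeply it is nested.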

In [152]:
def char_len(s: str) -> int:  # old len function
    """Return the number of characters in ``s``, whitespace included."""
    return len(s)


def non_whitespace_len(s: str) -> int:  # new len function
    """Return the number of characters in ``s`` that are not whitespace.

    Used instead of ``len`` so that indentation and blank lines do not count
    toward a chunk's size.
    """
    # Raw string: "\s" in a plain literal is an invalid escape sequence
    # (a SyntaxWarning on Python 3.12+).
    return len(re.sub(r"\s", "", s))

## Coalescing Chunks

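Combine small chunks with the ones that follow them: chunks accumulate until the combined span is longer than `coalesce` and contains a newline.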

In [None]:
def coalesce_chunks(
    chunks: list[Span], source_code: str, coalesce: int = 50
) -> list[Span]:
    """Merge runs of small spans so that each emitted span is big enough.

    Spans are accumulated left to right. The running span is emitted once its
    ``len`` (end - start) exceeds ``coalesce`` and its text contains a
    newline. Any leftover with positive length is emitted at the end.
    """
    merged: list[Span] = []
    pending = Span(0, 0)
    for span in chunks:
        pending = pending + span
        if not len(pending) > coalesce:
            continue
        if "\n" not in pending.extract(source_code):
            continue
        merged.append(pending)
        pending = Span(span.end, span.end)
    if len(pending) > 0:
        merged.append(pending)
    return merged


# Print the text of each coalesced chunk.
for chunk in coalesce_chunks(chunk_node(tree.root_node), python_code):
    print(chunk.extract(python_code))

## Use Line Numbers

Use line numbers instead of character indices. This works because `Span` is unit-agnostic.

In [None]:
def get_line_number(index: int, source_code: str) -> int:
    """Return the 0-based number of the line containing position ``index``.

    If ``index`` is at or past the end of ``source_code``, the total number
    of lines is returned instead, which works as an exclusive end bound for
    ``Span.extract_lines``. Lines are split with ``splitlines(keepends=True)``.
    An empty ``source_code`` yields 0.
    """
    total_chars = 0
    # Pre-bind so the fall-through return works when there are no lines at
    # all (previously an UnboundLocalError on empty input).
    line_number = 0
    for line_number, line in enumerate(source_code.splitlines(keepends=True), start=1):
        total_chars += len(line)
        if total_chars > index:
            return line_number - 1
    return line_number


# Print each coalesced chunk as a 0-based start-end line range (end exclusive).
for i, chunk in enumerate(coalesce_chunks(chunk_node(tree.root_node), python_code)):
    print(
        f"Chunk {i}: {get_line_number(chunk.start, python_code)}-{get_line_number(chunk.end, python_code)}"
    )

## Final New Algorithm

Putting it all together (switched back to `MAX_CHARS = 512 * 3`, i.e. 1536):

In [None]:
from tree_sitter import Tree


def chunker(
    tree: Tree,
    source_code: str | bytes,
    MAX_CHARS=512 * 3,
    coalesce=50,  # Chunks are merged forward until they exceed 50 non-whitespace characters and contain a newline
) -> list[Span]:
    """Split a parsed source file into line-number Spans.

    Args:
        tree: tree-sitter parse of the UTF-8 encoding of ``source_code``.
        source_code: the parsed text, as ``str`` or as UTF-8 ``bytes``.
        MAX_CHARS: approximate byte budget per chunk before coalescing.
        coalesce: a chunk is emitted only once it has more than this many
            non-whitespace characters and contains a newline.

    Returns:
        Non-empty Spans of 0-based line numbers with exclusive ends, usable
        with ``Span.extract_lines`` on the source text.
    """
    # tree-sitter reports byte offsets into the UTF-8 encoding it parsed, so
    # slice bytes; applying those offsets to a str drifts as soon as the file
    # contains a non-ASCII character.
    if isinstance(source_code, str):
        source_bytes = source_code.encode("utf-8")
        source_text = source_code
    else:
        source_bytes = source_code
        source_text = source_bytes.decode("utf-8", errors="replace")

    def span_text(span: Span) -> str:
        # Decoded text covered by a byte-offset span.
        return span.extract(source_bytes).decode("utf-8", errors="replace")

    def byte_to_char(offset: int) -> int:
        # Convert a byte offset (on a node boundary) into a str index.
        return len(source_bytes[:offset].decode("utf-8", errors="replace"))

    # 1. Recursively form chunks based on the last post (https://docs.sweep.dev/blogs/chunking-2m-files)
    def chunk_node(node: Node) -> list[Span]:
        chunks: list[Span] = []
        current_chunk: Span = Span(node.start_byte, node.start_byte)
        node_children = node.children
        for child in node_children:
            if child.end_byte - child.start_byte > MAX_CHARS:
                chunks.append(current_chunk)
                current_chunk = Span(child.end_byte, child.end_byte)
                chunks.extend(chunk_node(child))
            elif child.end_byte - child.start_byte + len(current_chunk) > MAX_CHARS:
                chunks.append(current_chunk)
                current_chunk = Span(child.start_byte, child.end_byte)
            else:
                current_chunk += Span(child.start_byte, child.end_byte)
        chunks.append(current_chunk)
        return chunks

    # chunk_node always appends at least one span, so chunks is non-empty.
    chunks = chunk_node(tree.root_node)

    # 2. Filling in the gaps: each chunk runs to the start of the next one,
    # and the last one runs to the end of the file. (This used to assign
    # `curr.start`, which dropped the file's tail and raised NameError when
    # there was only one chunk.)
    for prev, curr in zip(chunks[:-1], chunks[1:]):
        prev.end = curr.start
    chunks[-1].end = tree.root_node.end_byte

    # 3. Combining small chunks with bigger ones
    new_chunks = []
    current_chunk = Span(0, 0)
    for chunk in chunks:
        current_chunk += chunk
        chunk_text = span_text(current_chunk)
        if non_whitespace_len(chunk_text) > coalesce and "\n" in chunk_text:
            new_chunks.append(current_chunk)
            current_chunk = Span(chunk.end, chunk.end)
    if len(current_chunk) > 0:
        new_chunks.append(current_chunk)

    # 4. Changing line numbers (byte offset -> str index -> line number)
    line_chunks = [
        Span(
            get_line_number(byte_to_char(chunk.start), source_text),
            get_line_number(byte_to_char(chunk.end), source_text),
        )
        for chunk in new_chunks
    ]

    # 5. Eliminating empty chunks
    line_chunks = [chunk for chunk in line_chunks if len(chunk) > 0]

    return line_chunks


# Print the lines of each final chunk, separated by a rule.
for chunk in chunker(tree, python_code):
    print(chunk.extract_lines(python_code) + "\n\n====================\n\n")