# Updates on the Chunking Algorithm
This notebook is for the blog on improvements to our chunking algorithm

In [65]:
import re
from tree_sitter_languages import get_language, get_parser

language = get_language('python')
parser = get_parser('python')

## Meet the Span

In [89]:
from __future__ import annotations
from dataclasses import dataclass, field

@dataclass
class Span:
    start: int = 0
    end: int = field(default=lambda self: self.start)

    def extract(self, s: str) -> str:
        return "\n".join(s.splitlines()[self.start:self.end])

    def __add__(self, other) -> Span:
        if isinstance(other, int):
            return Span(self.start + other, self.end + other)
        elif isinstance(other, Span):
            return Span(self.start, other.end)
        else:
            raise NotImplementedError()

    def __len__(self) -> int:
        return self.end - self.start

The example code we're gonna use in this guide:

In [64]:
python_code = r'''
import io
import os
import zipfile

import openai
import requests
from loguru import logger

from sweepai.core.gha_extraction import GHAExtractor
from sweepai.events import CheckRunCompleted
from sweepai.handlers.on_comment import on_comment
from sweepai.utils.config.client import SweepConfig, get_gha_enabled
from sweepai.utils.github_utils import get_github_client, get_token

openai.api_key = os.environ.get("OPENAI_API_KEY")

log_message = """GitHub actions yielded the following error. 

{error_logs}

This is likely a linting or type-checking issue with the source code but if you are updating the GitHub Actions or versioning, this could be an issue with the GitHub Action yaml files."""

def download_logs(repo_full_name: str, run_id: int, installation_id: int):
    headers = {
        "Accept": "application/vnd.github+json",
        "Authorization": f"Bearer {get_token(installation_id)}",
        "X-GitHub-Api-Version": "2022-11-28"
    }
    response = requests.get(f"https://api.github.com/repos/{repo_full_name}/actions/runs/{run_id}/logs",
                            headers=headers)

    logs_str = ""
    if response.status_code == 200:
        zip_file = zipfile.ZipFile(io.BytesIO(response.content))
        for file in zip_file.namelist():
            if "/" not in file:
                with zip_file.open(file) as f:
                    logs_str += f.read().decode("utf-8")
    else:
        logger.warning(f"Failed to download logs for run id: {run_id}")
    return logs_str


def clean_logs(logs_str: str):
    log_list = logs_str.split("\n")
    truncated_logs = [log[log.find(" ") + 1:] for log in log_list]
    patterns = [
        # for docker
        "Already exists",
        "Pulling fs layer",
        "Waiting",
        "Download complete",
        "Verifying Checksum",
        "Pull complete",
        # For github
        "remote: Counting objects",
        "remote: Compressing objects:",
        "Receiving objects:",
        "Resolving deltas:"
    ]
    return "\n".join([log.strip() for log in truncated_logs if not any(pattern in log for pattern in patterns)])


def on_check_suite(request: CheckRunCompleted):
    logger.info(f"Received check run completed event for {request.repository.full_name}")
    g = get_github_client(request.installation.id)
    repo = g.get_repo(request.repository.full_name)
    if not get_gha_enabled(repo):
        logger.info(f"Skipping github action for {request.repository.full_name} because it is not enabled")
        return None
    pr = repo.get_pull(request.check_run.pull_requests[0].number)
    num_pr_commits = len(list(pr.get_commits()))
    if num_pr_commits > 20:
        logger.info(f"Skipping github action for PR with {num_pr_commits} commits")
        return None
    logger.info(f"Running github action for PR with {num_pr_commits} commits")
    logs = download_logs(
        request.repository.full_name,
        request.check_run.run_id,
        request.installation.id
    )
    if not logs:
        return None
    logs = clean_logs(logs)
    extractor = GHAExtractor()
    logger.info(f"Extracting logs from {request.repository.full_name}, logs: {logs}")
    problematic_logs = extractor.gha_extract(logs)
    if problematic_logs.count("\n") > 15:
        problematic_logs += "\n\nThere are a lot of errors. This is likely a larger issue with the PR and not a small linting/type-checking issue."
    comments = list(pr.get_issue_comments())
    if len(comments) >= 2 and problematic_logs == comments[-1].body and comments[-2].body == comments[-1].body:
        comment = pr.as_issue().create_comment(log_message.format(error_logs=problematic_logs) + "\n\nI'm getting the same errors 3 times in a row, so I will stop working on fixing this PR.")
        logger.warning("Skipping logs because it is duplicated")
        raise Exception("Duplicate error logs")
    print(problematic_logs)
    comment = pr.as_issue().create_comment(log_message.format(error_logs=problematic_logs))
    on_comment(
        repo_full_name=request.repository.full_name,
        repo_description=request.repository.description,
        comment=problematic_logs,
        pr_path=None,
        pr_line_position=None,
        username=request.sender.login,
        installation_id=request.installation.id,
        pr_number=request.check_run.pull_requests[0].number,
        comment_id=comment.id,
        repo=repo,
    )
    return {"success": True}
'''

expression_statement:432-696
  assignment:432-696
    string:446-696
function_definition:698-1502
  block:777-1502
    expression_statement:777-953
      assignment:777-953
        dictionary:787-953
    expression_statement:958-1103
      assignment:958-1103
        call:969-1103
    if_statement:1127-1482
      block:1167-1400
        for_statement:1232-1400
function_definition:1505-2107
  block:1540-2107
    expression_statement:1643-1994
      assignment:1643-1994
        list:1654-1994
function_definition:2110-4426
  block:2162-4426
    if_statement:2355-2512
      block:2393-2512
    if_statement:3212-3397
      block:3258-3397
        expression_statement:3258-3397
          augmented_assignment:3258-3397
    if_statement:3447-3859
      block:3563-3859
        expression_statement:3563-3746
          assignment:3563-3746
            call:3573-3746
              argument_list:3601-3746
                binary_operator:3602-3745
    expression_statement:3984-4397
      call:3984-4

Let's first visualize the syntax tree.

In [None]:
tree = parser.parse(python_code.encode("utf-8"))

def pretty_node(node):
    return f"{node.type}:{node.start_byte}-{node.end_byte}"

def print_tree(node, indent=""):
    if len(re.sub("\s", "", node.text.decode("utf-8"))) < 100:
        return
    print(indent + pretty_node(node))
    for child in node.children:
        print_tree(child, indent=indent + "  ")

for child in tree.root_node.children:
    print_tree(child)

We can see that it doesn't actually line up:

```
expression_statement:432-696 : log_message = "...
  assignment:432-696 : log_message = "...
    string:446-696 : """GitHub actio...
function_definition:698-1502 : def download_lo...
  block:777-1502 : headers = {    ...
    expression_statement:777-953 : headers = {    ...
      assignment:777-953 : headers = {    ...
        dictionary:787-953 : {         "Acce...
    expression_statement:958-1103 : response = requ...
      assignment:958-1103 : response = requ...
```

Notice that the “expression_statement” ends on byte 696 and “function_definition” starts on byte 698, skipping a byte.

In [67]:
def connect_chunks(chunks: list[Span]):
    for prev, curr in zip(chunks[:-1], chunks[1:]):
        prev.end = curr.start
    return chunks

## Coalescing
Recall this is the old algo (with some fixes).

In [94]:
from tree_sitter import Node

@dataclass
class MockNode:
    start_byte: int = 0
    end_byte: int = field(default_factory=lambda self: self.start_byte + 1)
    children: list[MockNode] = field(default_factory=list)

def chunk_node(
    node: Node, 
    text: str, 
    MAX_CHARS: int = 600
) -> list[str]:
    new_chunks = []
    current_chunk = ""
    node_children = node.children + [MockNode(node.end_byte, node.end_byte, [])]

    for child, next_child in zip(node.children[:-1], node.children[1:]):
        if child.end_byte - child.start_byte > MAX_CHARS:
            new_chunks.append(current_chunk)
            current_chunk = ""
            new_chunks.extend(chunk_node(child, text, MAX_CHARS))
        elif child.end_byte - child.start_byte + len(current_chunk) > MAX_CHARS:
            new_chunks.append(current_chunk)
            current_chunk = text[child.start_byte: next_child.start_byte]
        else:
            current_chunk += text[child.start_byte: next_child.start_byte]
    
    # # Handle last child
    # if next_child.end_byte - next_child.start_byte > MAX_CHARS:
    #     new_chunks.append(current_chunk)
    #     new_chunks.extend(chunk_node(next_child, text, MAX_CHARS))
    # elif next_child.end_byte - next_child.start_byte + len(current_chunk) > MAX_CHARS:
    #     new_chunks.append(current_chunk)
    #     new_chunks.append(next_child.text.decode("utf-8"))
    # else:
    #     new_chunks.append(current_chunk + next_child.text.decode("utf-8"))
    # new_chunks.append(current_chunk)
    return new_chunks

for chunk in chunk_node(tree.root_node, python_code):
    print(chunk + "\n\n====================\n\n")

TypeError: tree_sitter.Node() takes no arguments

Firstly, using Span's we can clean up the code a bit.

In [None]:
from tree_sitter import Node

def chunk_node(
    node: Node, 
    text: str, 
    MAX_CHARS: int = 600
) -> list[Span]:
    new_chunks: list[Span] = []
    current_chunk: Span = Span()
    for child in node.children:
        if child.end_byte - child.start_byte > MAX_CHARS:
            new_chunks.append(current_chunk)
            current_chunk = Span(child.end_byte + 1)
            new_chunks.extend(chunk_node(child, text, MAX_CHARS))
        elif child.end_byte - child.start_byte + len(current_chunk) > MAX_CHARS:
            new_chunks.append(current_chunk)
            current_chunk = Span(child.start_byte, child.end_byte) 
        else:
            current_chunk += Span(child.start_byte, child.end_byte)
    
    # Handle last child
    if next_child.end_byte - next_child.start_byte > MAX_CHARS:
        new_chunks.append(current_chunk)
        new_chunks.extend(chunk_node(next_child, text, MAX_CHARS))
    elif next_child.end_byte - next_child.start_byte + len(current_chunk) > MAX_CHARS:
        new_chunks.append(current_chunk)
        new_chunks.append(next_child.text.decode("utf-8"))
    else:
        new_chunks.append(current_chunk + next_child.text.decode("utf-8"))
    new_chunks.append(current_chunk)
    return new_chunks

for chunk in chunk_node(tree.root_node, python_code):
    print(chunk + "\n\n====================\n\n")