In [7]:
from tree_sitter import Language, Parser

Language.build_library(f'/tmp/python.so', [f"../tree-sitter-python"]) 
language = Language("/tmp/python.so", "python")
parser = Parser()
parser.set_language(language)

In [99]:
python_example = '''
from typing import Optional, cast
from chromadb.api import API
from chromadb.config import System
from chromadb.api.types import (
    Documents,
    Embeddings,
    EmbeddingFunction,
    IDs,
    Include,
    Metadatas,
    Where,
    WhereDocument,
    GetResult,
    QueryResult,
    CollectionMetadata,
)
import chromadb.utils.embedding_functions as ef
import pandas as pd
import requests
import json
from typing import Sequence
from chromadb.api.models.Collection import Collection
import chromadb.errors as errors
from uuid import UUID


class FastAPI(API):
    def __init__(self, system: System):
        url_prefix = "https" if system.settings.chroma_server_ssl_enabled else "http"
        system.settings.require("chroma_server_host")
        system.settings.require("chroma_server_http_port")
        self._api_url = f"{url_prefix}://{system.settings.chroma_server_host}:{system.settings.chroma_server_http_port}/api/v1"
        self._telemetry_client = system.get_telemetry()

    def heartbeat(self) -> int:
        """Returns the current server time in nanoseconds to check if the server is alive"""
        resp = requests.get(self._api_url)
        raise_chroma_error(resp)
        return int(resp.json()["nanosecond heartbeat"])

    def list_collections(self) -> Sequence[Collection]:
        """Returns a list of all collections"""
        resp = requests.get(self._api_url + "/collections")
        raise_chroma_error(resp)
        json_collections = resp.json()
        collections = []
        for json_collection in json_collections:
            collections.append(Collection(self, **json_collection))

        return collections

    def create_collection(
        self,
        name: str,
        metadata: Optional[CollectionMetadata] = None,
        embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
        get_or_create: bool = False,
    ) -> Collection:
        """Creates a collection"""
        resp = requests.post(
            self._api_url + "/collections",
            data=json.dumps(
                {"name": name, "metadata": metadata, "get_or_create": get_or_create}
            ),
        )
        raise_chroma_error(resp)
        resp_json = resp.json()
        return Collection(
            client=self,
            id=resp_json["id"],
            name=resp_json["name"],
            embedding_function=embedding_function,
            metadata=resp_json["metadata"],
        )

    def get_collection(
        self,
        name: str,
        embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
    ) -> Collection:
        """Returns a collection"""
        resp = requests.get(self._api_url + "/collections/" + name)
        raise_chroma_error(resp)
        resp_json = resp.json()
        return Collection(
            client=self,
            name=resp_json["name"],
            id=resp_json["id"],
            embedding_function=embedding_function,
            metadata=resp_json["metadata"],
        )

    def get_or_create_collection(
        self,
        name: str,
        metadata: Optional[CollectionMetadata] = None,
        embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
    ) -> Collection:
        """Get a collection, or return it if it exists"""

        return self.create_collection(
            name, metadata, embedding_function, get_or_create=True
        )

    def _modify(
        self,
        id: UUID,
        new_name: Optional[str] = None,
        new_metadata: Optional[CollectionMetadata] = None,
    ) -> None:
        """Updates a collection"""
        resp = requests.put(
            self._api_url + "/collections/" + str(id),
            data=json.dumps({"new_metadata": new_metadata, "new_name": new_name}),
        )
        raise_chroma_error(resp)

    def delete_collection(self, name: str) -> None:
        """Deletes a collection"""
        resp = requests.delete(self._api_url + "/collections/" + name)
        raise_chroma_error(resp)

    def _count(self, collection_id: UUID) -> int:
        """Returns the number of embeddings in the database"""
        resp = requests.get(
            self._api_url + "/collections/" + str(collection_id) + "/count"
        )
        raise_chroma_error(resp)
        return cast(int, resp.json())

    def _peek(self, collection_id: UUID, limit: int = 10) -> GetResult:
        return self._get(
            collection_id,
            limit=limit,
            include=["embeddings", "documents", "metadatas"],
        )

    def _get(
        self,
        collection_id: UUID,
        ids: Optional[IDs] = None,
        where: Optional[Where] = {},
        sort: Optional[str] = None,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        page: Optional[int] = None,
        page_size: Optional[int] = None,
        where_document: Optional[WhereDocument] = {},
        include: Include = ["metadatas", "documents"],
    ) -> GetResult:
        """Gets embeddings from the database"""
        if page and page_size:
            offset = (page - 1) * page_size
            limit = page_size

        resp = requests.post(
            self._api_url + "/collections/" + str(collection_id) + "/get",
            data=json.dumps(
                {
                    "ids": ids,
                    "where": where,
                    "sort": sort,
                    "limit": limit,
                    "offset": offset,
                    "where_document": where_document,
                    "include": include,
                }
            ),
        )

        raise_chroma_error(resp)
        body = resp.json()
        return GetResult(
            ids=body["ids"],
            embeddings=body.get("embeddings", None),
            metadatas=body.get("metadatas", None),
            documents=body.get("documents", None),
        )

    def _delete(
        self,
        collection_id: UUID,
        ids: Optional[IDs] = None,
        where: Optional[Where] = {},
        where_document: Optional[WhereDocument] = {},
    ) -> IDs:
        """Deletes embeddings from the database"""

        resp = requests.post(
            self._api_url + "/collections/" + str(collection_id) + "/delete",
            data=json.dumps(
                {"where": where, "ids": ids, "where_document": where_document}
            ),
        )

        raise_chroma_error(resp)
        return cast(IDs, resp.json())

    def _add(
        self,
        ids: IDs,
        collection_id: UUID,
        embeddings: Embeddings,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        increment_index: bool = True,
    ) -> bool:
        """
        Adds a batch of embeddings to the database
        - pass in column oriented data lists
        - by default, the index is progressively built up as you add more data. If for ingestion performance reasons you want to disable this, set increment_index to False
        -     and then manually create the index yourself with collection.create_index()
        """
        resp = requests.post(
            self._api_url + "/collections/" + str(collection_id) + "/add",
            data=json.dumps(
                {
                    "ids": ids,
                    "embeddings": embeddings,
                    "metadatas": metadatas,
                    "documents": documents,
                    "increment_index": increment_index,
                }
            ),
        )

        raise_chroma_error(resp)
        return True

    def _update(
        self,
        collection_id: UUID,
        ids: IDs,
        embeddings: Optional[Embeddings] = None,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
    ) -> bool:
        """
        Updates a batch of embeddings in the database
        - pass in column oriented data lists
        """

        resp = requests.post(
            self._api_url + "/collections/" + str(collection_id) + "/update",
            data=json.dumps(
                {
                    "ids": ids,
                    "embeddings": embeddings,
                    "metadatas": metadatas,
                    "documents": documents,
                }
            ),
        )

        resp.raise_for_status()
        return True

    def _upsert(
        self,
        collection_id: UUID,
        ids: IDs,
        embeddings: Embeddings,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        increment_index: bool = True,
    ) -> bool:
        """
        Updates a batch of embeddings in the database
        - pass in column oriented data lists
        """

        resp = requests.post(
            self._api_url + "/collections/" + str(collection_id) + "/upsert",
            data=json.dumps(
                {
                    "ids": ids,
                    "embeddings": embeddings,
                    "metadatas": metadatas,
                    "documents": documents,
                    "increment_index": increment_index,
                }
            ),
        )

        resp.raise_for_status()
        return True

    def _query(
        self,
        collection_id: UUID,
        query_embeddings: Embeddings,
        n_results: int = 10,
        where: Optional[Where] = {},
        where_document: Optional[WhereDocument] = {},
        include: Include = ["metadatas", "documents", "distances"],
    ) -> QueryResult:
        """Gets the nearest neighbors of a single embedding"""

        resp = requests.post(
            self._api_url + "/collections/" + str(collection_id) + "/query",
            data=json.dumps(
                {
                    "query_embeddings": query_embeddings,
                    "n_results": n_results,
                    "where": where,
                    "where_document": where_document,
                    "include": include,
                }
            ),
        )

        raise_chroma_error(resp)
        body = resp.json()

        return QueryResult(
            ids=body["ids"],
            distances=body.get("distances", None),
            embeddings=body.get("embeddings", None),
            metadatas=body.get("metadatas", None),
            documents=body.get("documents", None),
        )

    def reset(self) -> bool:
        """Resets the database"""
        resp = requests.post(self._api_url + "/reset")
        raise_chroma_error(resp)
        return cast(bool, resp.json())

    def persist(self) -> bool:
        """Persists the database"""
        resp = requests.post(self._api_url + "/persist")
        raise_chroma_error(resp)
        return cast(bool, resp.json())

    def raw_sql(self, sql: str) -> pd.DataFrame:
        """Runs a raw SQL query against the database"""
        resp = requests.post(
            self._api_url + "/raw_sql", data=json.dumps({"raw_sql": sql})
        )
        raise_chroma_error(resp)
        return pd.DataFrame.from_dict(resp.json())

    def create_index(self, collection_name: str) -> bool:
        """Creates an index for the given space key"""
        resp = requests.post(
            self._api_url + "/collections/" + collection_name + "/create_index"
        )
        raise_chroma_error(resp)
        return cast(bool, resp.json())

    def get_version(self) -> str:
        """Returns the version of the server"""
        resp = requests.get(self._api_url + "/version")
        raise_chroma_error(resp)
        return cast(str, resp.json())


def raise_chroma_error(resp: requests.Response) -> None:
    """Raises an error if the response is not ok, using a ChromaError if possible"""
    if resp.ok:
        return

    chroma_error = None
    try:
        body = resp.json()
        if "error" in body:
            if body["error"] in errors.error_types:
                chroma_error = errors.error_types[body["error"]](body["message"])

    except BaseException:
        pass

    if chroma_error:
        raise chroma_error

    try:
        resp.raise_for_status()
    except requests.HTTPError:
        raise (Exception(resp.text))
'''

tree = parser.parse(bytes(python_example, "utf-8"))

In [100]:
# pretty print tree recursively
from tree_sitter import Node

threshold = 1000
def pretty_print_tree(node: Node, source_code, indent=0):
    result = ''
    result += ' ' * indent + f'Type: {node.type}, Slice: {node.start_byte}:{node.end_byte}, Content: {source_code[node.start_byte:min(node.end_byte, node.start_byte + 20)]}\n'

    if node.end_byte - node.start_byte > threshold:
        for child in node.children:
            result += pretty_print_tree(child, source_code, indent + 2)

    return result

print(pretty_print_tree(tree.root_node, python_example))


Type: module, Slice: 1:12142, Content: from typing import O
  Type: import_from_statement, Slice: 1:34, Content: from typing import O
  Type: import_from_statement, Slice: 35:63, Content: from chromadb.api im
  Type: import_from_statement, Slice: 64:98, Content: from chromadb.config
  Type: import_from_statement, Slice: 99:310, Content: from chromadb.api.ty
  Type: import_statement, Slice: 311:358, Content: import chromadb.util
  Type: import_statement, Slice: 359:378, Content: import pandas as pd
  Type: import_statement, Slice: 379:394, Content: import requests
  Type: import_statement, Slice: 395:406, Content: import json
  Type: import_from_statement, Slice: 407:434, Content: from typing import S
  Type: import_from_statement, Slice: 435:488, Content: from chromadb.api.mo
  Type: import_statement, Slice: 489:521, Content: import chromadb.erro
  Type: import_from_statement, Slice: 522:543, Content: from uuid import UUI
  Type: class_definition, Slice: 546:11544, Content: class FastA

In [125]:
from dataclasses import dataclass

@dataclass
class Chunk:
    start: int
    end: int

    def extract(self, s: str) -> str:
        # return s[self.start:self.end]
        return "\n".join(s.splitlines()[self.start:self.end])

    def __add__(self, other):
        if isinstance(other, int):
            return Chunk(self.start + other, self.end + other)
        elif isinstance(other, Chunk):
            return Chunk(self.start, other.end)
        else:
            raise NotImplementedError()
    
    def __len__(self):
        return self.end - self.start

In [92]:
def naive_chunker(source_code: str, chunk_lines: int = 20, start_position: int = 0) -> list[Chunk]:
    source_lines = source_code.split('\n')
    num_lines = len(source_lines)

    chunks = []

    start = 0
    start_line = 0
    while start_line < num_lines:
        end_line = min(start_line + chunk_lines, num_lines)
        chunk_string = '\n'.join(source_code[start_line:end_line]) + "\n"
        if len(chunk_string) > 2:
            chunks.append(Chunk(start + start_position, start + len(chunk_string) + start_position))
            start += len(chunk_string)
        start_line += chunk_lines
    return chunks

In [126]:
def flatten(node: Node, source_code, max_chunk_size=1500) -> list[tuple[int, int]]:
    if node.end_byte - node.start_byte > max_chunk_size:
        result = []
        for child in node.children:
            result += flatten(child, source_code)
        for prev, curr in zip(result[:-1], result[1:]): # ensures there are no gaps
            prev[1] = curr[0]
        return result
    else:
        return [[node.start_byte, node.end_byte]]

def char_number_to_line_number(index: int, source_code: str) -> int:
    lines = source_code.splitlines(keepends=True)
    total_chars = 0
    line_number = 0
    while total_chars <= index:
        total_chars += len(lines[line_number])
        line_number += 1
    return line_number - 1

flattened = flatten(tree.root_node, python_example)
flattened_line_numbers = [Chunk(char_number_to_line_number(start, python_example), char_number_to_line_number(end, python_example)) for start, end in flattened]
for chunk in flattened_line_numbers:
    print(chunk.extract(python_example) + "\n\n")

from typing import Optional, cast


from chromadb.api import API


from chromadb.config import System


from chromadb.api.types import (
    Documents,
    Embeddings,
    EmbeddingFunction,
    IDs,
    Include,
    Metadatas,
    Where,
    WhereDocument,
    GetResult,
    QueryResult,
    CollectionMetadata,
)


import chromadb.utils.embedding_functions as ef


import pandas as pd


import requests


import json


from typing import Sequence


from chromadb.api.models.Collection import Collection


import chromadb.errors as errors


from uuid import UUID













class FastAPI(API):


    def __init__(self, system: System):
        url_prefix = "https" if system.settings.chroma_server_ssl_enabled else "http"
        system.settings.require("chroma_server_host")
        system.settings.require("chroma_server_http_port")
        self._api_url = f"{url_prefix}://{system.settings.chroma_server_host}:{system.settings.chroma_server_http_port}/api/v1"
        self._telemetry_client = 

In [129]:
def chunker(source_code: str, max_chunk_size=1500) -> list[Chunk]:
    tree = parser.parse(bytes(source_code, "utf-8"))
    flattened = [Chunk(char_number_to_line_number(start, source_code), char_number_to_line_number(end, source_code)) for start, end in flatten(tree.root_node, source_code)]

    chunks = []
    current_chunk = Chunk(0, 0)
    lines = source_code.splitlines(keepends=True)
    for chunk in flattened:
        if sum([len(line) for line in lines[current_chunk.start:current_chunk.end]])\
            + sum([len(line) for line in lines[chunk.start:chunk.end]]) > max_chunk_size:
            chunks.append(current_chunk)
            current_chunk = chunk
        else:
            current_chunk += chunk
    if current_chunk:
        chunks.append(current_chunk)
    return chunks

for chunk in chunker(python_example):
    # print(chunk)
    print(chunk.extract(python_example) + "\n\n\n")



from typing import Optional, cast
from chromadb.api import API
from chromadb.config import System
from chromadb.api.types import (
    Documents,
    Embeddings,
    EmbeddingFunction,
    IDs,
    Include,
    Metadatas,
    Where,
    WhereDocument,
    GetResult,
    QueryResult,
    CollectionMetadata,
)
import chromadb.utils.embedding_functions as ef
import pandas as pd
import requests
import json
from typing import Sequence
from chromadb.api.models.Collection import Collection
import chromadb.errors as errors
from uuid import UUID


class FastAPI(API):
    def __init__(self, system: System):
        url_prefix = "https" if system.settings.chroma_server_ssl_enabled else "http"
        system.settings.require("chroma_server_host")
        system.settings.require("chroma_server_http_port")
        self._api_url = f"{url_prefix}://{system.settings.chroma_server_host}:{system.settings.chroma_server_http_port}/api/v1"
        self._telemetry_client = system.get_telemetry()

    def he

In [75]:
for chunk in naive_chunker(python_example, 20):
    print(len(chunk))
    # print(chunk + "\n\n\n")

40
0


ValueError: __len__() should return >= 0