In [None]:
!pip install tree-sitter -q
!pip install tree-sitter-language-pack -q
!pip install tree-sitter-javascript -q
!pip install tree-sitter-typescript -q
!pip install hnswlib -q

# 1. Parse Code

In [None]:
from tree_sitter import Language, Parser
import tree_sitter_typescript
import tree_sitter_javascript

# Wrap the grammar handles shipped by the language packages. Both grammars are
# built here, but only the JavaScript one is handed to the parser below.
JS_LANGUAGE = Language(tree_sitter_javascript.language())
TS_LANGUAGE = Language(tree_sitter_typescript.language_typescript())

# Sample source to chunk: a dynamic import followed by an empty export.
code = """
import('./bootstrap')
export {}
"""

# The parser is fed the UTF-8 encoding of the sample.
parser = Parser(JS_LANGUAGE)
xtree = parser.parse(code.encode("utf8"))

# 2. Explore parsed tree structure

In [None]:
root = xtree.root_node

# start_byte/end_byte index into the UTF-8 bytes that were parsed, not into the
# Python str. Slicing the str directly only happens to work for pure-ASCII
# source, so slice the encoded bytes and decode the slice instead.
source_bytes = code.encode("utf8")

# Show each top-level node's type next to the exact source text it covers.
for child in root.children:
    snippet = source_bytes[child.start_byte:child.end_byte].decode("utf8")
    print(child.type, " -> ", snippet)

# 3. Chunking Code with Tree-Sitter

In [None]:
from collections import deque

# Node types that are emitted as chunks.
terminal = [
    'import_statement',
    'lexical_declaration',
    'expression_statement',
    'export_statement'
]

def extract_subtree(subtree_root):
  """Collect every descendant of `subtree_root` whose type is in `terminal`.

  The walk is breadth-first and does not stop at a matching node, so a terminal
  node nested inside another terminal node is returned as well. Children whose
  type is "\\n" are not descended into.

  Args:
    subtree_root: node exposing `.children`; every node must expose `.type`.

  Returns:
    List of matching descendant nodes in breadth-first order. `subtree_root`
    itself is never included.
  """
  # deque gives O(1) pops from the left; list.pop(0) shifts the whole list.
  queue = deque([subtree_root])
  subtree_nodes = []
  ignore_types = ["\n"]
  while queue:
    current_node = queue.popleft()
    for child in current_node.children:
      child_type = str(child.type)
      if child_type not in ignore_types:
        queue.append(child)
      if child_type in terminal:
        subtree_nodes.append(child)
  return subtree_nodes

def extract_subtrees(tree):
  """Return the chunk nodes of a parsed tree, each exactly once.

  Previously the full descendant scan was repeated for every non-terminal node
  and terminal nodes were appended again when they were visited, so the same
  node showed up several times in the result. One scan from the root already
  finds every terminal descendant, so that single pass is all that is needed.

  Args:
    tree: object exposing `.root_node`.

  Returns:
    `[root]` if the root itself has a terminal type; otherwise all descendants
    with a terminal type, in breadth-first order, without duplicates.
  """
  root = tree.root_node
  if str(root.type) in terminal:
    # A terminal root is the chunk; nothing below it is collected.
    return [root]
  return extract_subtree(root)



In [None]:
# Run the chunker over the tree parsed earlier; `subtrees` holds the node
# objects (not their text) that the next section turns into strings.
subtrees = extract_subtrees(xtree)

# Raw repr of the collected nodes, for a quick visual check.
print(subtrees)

# 4. Convert AST Nodes to Text for Embeddings

In [None]:
# Node offsets are byte positions in the UTF-8 source that was parsed, so cut
# the chunks out of the encoded bytes and decode each slice. Slicing the str
# with those offsets is only correct while the source is pure ASCII.
source_bytes = code.encode("utf8")
chunk_texts = (
  source_bytes[subtree.start_byte:subtree.end_byte].decode("utf8")
  for subtree in subtrees
)

# dict.fromkeys drops repeated texts while keeping first-seen order, replacing
# the quadratic "not in list" scan.
src_texts = list(dict.fromkeys(chunk_texts))

In [None]:
import torch
from transformers import AutoModel, AutoTokenizer
### from optimum.bettertransformer import BetterTransformer

# Checkpoint used for both the code chunks and the search query.
embedding_model = "Salesforce/codet5p-110m-embedding"

# Use the GPU when one is visible, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# NOTE(review): trust_remote_code=True runs Python code shipped in the model
# repository; consider pinning a revision before using this outside a demo.
tokenizer = AutoTokenizer.from_pretrained(embedding_model, trust_remote_code=True)
model = AutoModel.from_pretrained(embedding_model, trust_remote_code=True)
model = model.to(device)

# Disabled experiment kept for reference: loading with torch_dtype=torch.float16,
# setting model.config.model_type = 't5', converting with to_bettertransformer()
# and calling model.eval().

# The triple-quoted string below is never assigned or executed; it is an inert
# usage example for the same checkpoint (presumably copied from the model's
# documentation - confirm before relying on the numbers quoted in it).
"""
from transformers import AutoModel, AutoTokenizer

checkpoint = "Salesforce/codet5p-110m-embedding"
device = "cuda"  # for GPU usage or "cpu" for CPU usage

tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
model = AutoModel.from_pretrained(checkpoint, trust_remote_code=True).to(device)

inputs = tokenizer.encode("def print_hello_world():\tprint('Hello World!')", return_tensors="pt").to(device)
embedding = model(inputs)[0]
print(f'Dimension of the embedding: {embedding.size()[0]}, with norm={embedding.norm().item()}')
# Dimension of the embedding: 256, with norm=1.0
print(embedding)
"""

def get_embedding(texts, max_length=2048):
  """Embed one piece of text with the module-level `tokenizer` and `model`.

  Args:
    texts: the string to embed. Despite the plural name, it is passed to a
      single `tokenizer.encode` call and produces one embedding.
    max_length: maximum number of tokens kept; longer inputs are truncated.
      This parameter was previously accepted but ignored.

  Returns:
    The first element of the model output as a CPU tensor, so callers can call
    `.numpy()` on it whether or not the model runs on a GPU.
  """
  # Inference only: no autograd bookkeeping is needed.
  with torch.no_grad():
    inputs = tokenizer.encode(
      texts,
      return_tensors="pt",
      truncation=True,
      max_length=max_length,
    ).to(device)
    # Move the result to host memory; Tensor.numpy() rejects CUDA tensors.
    return model(inputs)[0].cpu()

# One embedding per unique chunk, in the same order as src_texts so that list
# positions can later serve as index labels.
embeddings = [get_embedding(chunk_text) for chunk_text in src_texts]

# The natural-language query is embedded with the same model as the chunks.
query_embedding = get_embedding("find code that import bootstrap")


In [None]:
import numpy as np

print(type(embeddings)) ## list of tensor
print(len(embeddings))
print(type(query_embedding)) ## tensor
# print(embeddings[0])

# Tensor.numpy() only works on CPU tensors without grad, so detach and move
# each embedding to host memory first; this is a no-op when already on the CPU.
list_of_arrays = [emb.detach().cpu().numpy() for emb in embeddings]
# One row per chunk; axis 1 is what the index cell below uses as the dimension.
src_emb = np.stack(list_of_arrays)


print(type(src_emb))
print(len(src_emb))
print(src_emb.shape[1])
print(src_emb.shape)
print(src_emb)

# 5. Storing and Retrieving Code Chunks

In [None]:
import hnswlib
import numpy as np

dim = src_emb.shape[1]
print(f"Dimension: {dim}")

num_elements = len(src_emb)
print(f"Number of elements: {num_elements}")

# Build a cosine-space HNSW index sized for exactly the chunks we have. Labels
# are the row numbers, i.e. positions in src_texts.
index = hnswlib.Index(space='cosine', dim=dim)
index.init_index(max_elements=num_elements, ef_construction=200, M=16)
index.add_items(src_emb, np.arange(num_elements))

# hnswlib takes NumPy input, and Tensor.numpy() raises for CUDA tensors, so
# bring the query embedding to host memory before converting it.
query_vector = query_embedding.detach().cpu().numpy()

# k=1 asks for the single most similar chunk; raise k (up to num_elements) to
# retrieve more candidates.
labels, distances = index.knn_query(query_vector, k=1)

print(f"Nearest neighbors' labels: {labels}")
print(f"Distance: {distances}")

# labels has one row per query; map its entries back to the chunk texts.
print(f"Retrieved documents: {[src_texts[i] for i in labels[0]]}")