# Code Retrieval using Embeddings

## Load Embedding Model

We use [CodeT5+](https://huggingface.co/Salesforce/codet5p-110m-embedding) as the embedding model.

* Maximum input: 512 tokens
* Output dimensions: 256

In [None]:
from transformers import AutoModel, AutoTokenizer
import torch

In [None]:
gpu = torch.device('cuda:0')
model_id = "Salesforce/codet5p-110m-embedding"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModel.from_pretrained(model_id, device_map=gpu, torch_dtype=torch.bfloat16, trust_remote_code=True)

In [None]:
def embed(code):
    """Use the model to embed the given code"""
    # truncate to the model's maximum input length of 512 tokens
    inputs = tokenizer.encode(code, return_tensors='pt', truncation=True, max_length=512).to(gpu)
    output = model(inputs)[0]

    # NumPy has no bfloat16 dtype, so convert to float32 before pulling the result to the CPU
    return output.detach().to(torch.float32).cpu().numpy()

In [None]:
# Let's check out an example embedding:
example_code = '''def fib(n):
    return fib(n - 1) + fib(n - 2) if n > 1 else n
    '''
embed(example_code)

## 👩‍💻 Build a Simple Code "Database"

We need some code to retrieve. For this example, we use all functions in the [Flask](https://github.com/pallets/flask) web framework, extracting them using GitPython and tree-sitter.
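
To make the ID scheme concrete, here is a minimal sketch that extracts functions from a single source string with Python's standard-library `ast` module. This is a swapped-in technique for illustration only; the notebook itself uses tree-sitter. The helper name `functions_in_source` and the sample inputs are hypothetical.

```python
import ast

def functions_in_source(repo, file_name, source):
    """Map '<repo>:<file_name>:<function_name>' to the function's source text."""
    tree = ast.parse(source)
    found = {}
    for node in ast.walk(tree):
        # covers plain and async function definitions, including methods
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            key = f'{repo}:{file_name}:{node.name}'
            found[key] = ast.get_source_segment(source, node)
    return found

sample = '''
def add(a, b):
    return a + b

class Greeter:
    def greet(self, name):
        return f"Hello, {name}"
'''

demo_functions = functions_in_source('demo-repo', 'sample.py', sample)
for key in sorted(demo_functions):
    print(key)
```

Note that with this kind of ID, two functions that share a name within the same file map to the same key, so the later one overwrites the earlier one.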

In [None]:
import git
import autopep8
import tree_sitter_python as tspython
from tree_sitter import Language, Parser
import numpy as np

PY_LANGUAGE = Language(tspython.language())

In [None]:
def repo_to_functions(repo_location):
    """Extract all Python function definitions from the given Git repository, identifying each by <repo>:<file_name>:<function_name>"""

    # fetch top commit's tree
    repo = git.Repo(repo_location)
    tree = repo.head.commit.tree

    # parser and query for Python
    parser = Parser()
    parser.language = PY_LANGUAGE
    query = PY_LANGUAGE.query('''(function_definition) @func''')

    # read all .py files
    files = [(item.name, item.data_stream.read())
             for item in tree.list_traverse()
             if item.type == 'blob'
             and item.name.endswith('.py')]

    def function_name(node):
        return node.child_by_field_name('name').text.decode('utf-8')
    
    # query all functions in all files. We use <repo>:<file_name>:<function_name> as "ID"
    functions = {f'{repo_location}:{name}:{function_name(node)}' : autopep8.fix_code(node.text)
                 for name, file in files
                 for node, _ in query.captures(parser.parse(file).root_node)}

    return functions

In [None]:
!git clone --bare https://github.com/pallets/flask.git

# related repository suggestions:
# - https://github.com/pallets/werkzeug.git
# - https://github.com/pallets/jinja.git


In [None]:
%%time
# build a dictionary of functions
functions = repo_to_functions('./flask.git')
# add other repositories:
#functions.update(repo_to_functions('<OTHER REPO>'))


In [None]:
print(f'{len(functions)} functions extracted')

## 📖 Populate a Simple "Vector Database"

* We compute the code embedding for each function and store it under the function's ID (`<repo>:<file_name>:<function_name>`).
* The retriever embeds the query and ranks every item by its dot product with the query embedding, which equals the cosine similarity provided the embeddings have unit length.
* ⚠️ For larger databases, use a real vector database or a specialized framework (e.g., [LlamaIndex](https://docs.llamaindex.ai/en/stable/)).
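
The notebook's `retrieve` function scores items with `np.dot`, which matches cosine similarity only for unit-length vectors. The following minimal sketch shows the ranking step and that relationship on toy data, using plain Python. All names (`rank`, `toy_items`, `toy_query`) and the three-dimensional vectors are hypothetical stand-ins for real embeddings.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def normalize(v):
    norm = math.sqrt(dot(v, v))
    return [x / norm for x in v]

def cosine_similarity(a, b):
    return dot(a, b) / (math.sqrt(dot(a, a)) * math.sqrt(dot(b, b)))

def rank(query, items):
    """Sort (name, vector) items by cosine similarity to the query, best first."""
    scored = [(name, cosine_similarity(query, vec)) for name, vec in items.items()]
    return sorted(scored, key=lambda item: item[1], reverse=True)

toy_items = {
    'login_test': [0.9, 0.1, 0.0],
    'render_template': [0.0, 1.0, 0.2],
    'logout_test': [0.7, 0.0, 0.7],
}
toy_query = [1.0, 0.0, 0.0]
toy_ranking = rank(toy_query, toy_items)
print([name for name, _ in toy_ranking])

# for unit-length vectors, the plain dot product equals the cosine similarity
a, b = normalize(toy_items['logout_test']), normalize(toy_query)
print(math.isclose(dot(a, b), cosine_similarity(a, b)))
```

If the stored vectors were not normalized, a plain dot product would favor vectors with a larger norm, so normalizing first (or dividing by the norms as above) keeps the ranking purely direction-based.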

In [None]:
%%time
# compute embeddings for each function
embeddings = {name : embed(code) for name, code in functions.items()}

In [None]:
def retrieve(embeddings, query):
    """Rank embedded items by their similarity to the query"""
    
    query_embed = embed(query)
    similarities = [(name, np.dot(query_embed, embedding))
                    for name, embedding in embeddings.items()]
    return sorted(similarities, key=lambda item: item[1], reverse=True)


## Test Retrieval

In [None]:
top_10 = retrieve(embeddings, "# test whether a user can log in")[:10]

In [None]:
def print_retrieval_results(ranked_results):
    """Print rank, ID, similarity score, and source code of each result"""
    for index, (name, similarity) in enumerate(ranked_results):
        print('=' * 80)
        print(f'{index + 1}: {name} ({similarity:.2f})')
        print('-' * 80)
        print(functions[name])

# Print top 10 results:
print_retrieval_results(top_10)

## Visualize Embedding Space

We map each 256-dimensional vector to a 2D vector using PCA and plot the result.
* Tests are red
* Non-test functions are blue
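
For intuition about what the projection does, here is a minimal sketch of PCA via a singular value decomposition in NumPy. It runs on random stand-in data rather than the real embeddings, and the helper name `pca_project` is hypothetical; the notebook itself relies on scikit-learn's `PCA`.

```python
import numpy as np

def pca_project(vectors, n_components=2):
    """Project row vectors onto their top principal components."""
    data = np.asarray(vectors, dtype=np.float64)
    centered = data - data.mean(axis=0)  # PCA operates on mean-centered data
    # rows of vt are the principal directions, ordered by explained variance
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:n_components].T

rng = np.random.default_rng(0)
toy_vectors = rng.normal(size=(50, 256))  # stand-in for 50 embeddings of dimension 256
toy_projected = pca_project(toy_vectors)
print(toy_projected.shape)
```

The result should agree with scikit-learn's projection up to the sign of each axis, since the direction of a principal component is only determined up to a flip.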

In [None]:
from sklearn.decomposition import PCA
from matplotlib import pyplot as plt

In [None]:
pca = PCA(2)
keys = list(embeddings.keys())
vectors = list(embeddings.values())

# do the PCA transformation
projected = pca.fit_transform(vectors)

# define a color for each data point based on the keys
colors = ['red' if 'test_' in k else 'blue' for k in keys]

In [None]:
plt.scatter(x=projected[:, 0], y=projected[:, 1], color=colors)
plt.show()

## Generation-augmented Retrieval (GAR)

* Generation-augmented retrieval expands the user's query using an LLM.
* Retrieval then compares the stored embeddings against the embedding of the generated query.
* Here, we use **code completion** to turn a **natural-language** prompt into example code, which should be **easier to match against other code** because the query and the candidates are now both code.
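
The steps above can be sketched end to end with toy components. In this minimal example, `toy_generate` stands in for the code LLM (it just appends a canned completion) and `toy_embed` stands in for the embedding model (a bag-of-words count vector). Both are hypothetical placeholders chosen so the sketch runs without any model; only the order of the steps mirrors the approach described here.

```python
import math
import re
from collections import Counter

def toy_embed(text):
    """Hypothetical stand-in for the embedding model: a bag-of-words count vector."""
    return Counter(re.findall(r'[a-z_]+', text.lower()))

def toy_similarity(a, b):
    """Cosine similarity between two sparse count vectors."""
    dot = sum(a[token] * b[token] for token in a)
    norm = (math.sqrt(sum(v * v for v in a.values()))
            * math.sqrt(sum(v * v for v in b.values())))
    return dot / norm if norm else 0.0

def toy_generate(prompt):
    """Hypothetical stand-in for the code LLM: prompt plus a canned completion."""
    return prompt + ("def test_login(client):\n"
                     "    response = client.post('/login')\n"
                     "    assert response.status_code == 200\n")

def toy_generate_retrieve(index, prompt):
    generated = toy_generate(prompt)    # 1. expand the query with generated code
    query_vec = toy_embed(generated)    # 2. embed the expanded query
    scored = [(name, toy_similarity(query_vec, vec)) for name, vec in index.items()]
    ranked = sorted(scored, key=lambda item: item[1], reverse=True)  # 3. rank items
    return ranked, generated

toy_code = {
    'test_login': "def test_login(client):\n    assert client.post('/login').status_code == 200\n",
    'render': "def render(template, context):\n    return template.format(**context)\n",
}
toy_index = {name: toy_embed(code) for name, code in toy_code.items()}
toy_ranked, toy_generated = toy_generate_retrieve(toy_index, "# check that a user can log in\n")
print(toy_ranked[0][0])
```

The function returns both the ranking and the generated query, so the caller should unpack the pair first and slice the ranking afterwards.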

In [None]:
from transformers import GemmaTokenizer, AutoModelForCausalLM
gen_model_id = "google/codegemma-1.1-2b"
gen_tokenizer = GemmaTokenizer.from_pretrained(gen_model_id)
gen_model = AutoModelForCausalLM.from_pretrained(gen_model_id, device_map=gpu, torch_dtype=torch.bfloat16)

In [None]:
# A standard greedy generation helper
def generate(prompt, max_new_tokens=128):
    inputs = gen_tokenizer.encode(prompt, return_tensors='pt').to(gpu)
    outputs = gen_model.generate(inputs, max_new_tokens=max_new_tokens)
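    # skip special tokens (e.g., the beginning-of-sequence marker) so they do not end up in the query
    return gen_tokenizer.decode(outputs[0], skip_special_tokens=True)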

def generate_retrieve(embeddings, prompt):
    generated = generate(prompt)
    return retrieve(embeddings, generated), generated

In [None]:
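# generate_retrieve returns a (ranking, generated query) pair, so unpack first and slice afterwards
ranked, generated = generate_retrieve(embeddings, "# A flask test case to verify user login functionality:\n")
top_10 = ranked[:10]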

print('Generated query:')
print(generated)

print_retrieval_results(top_10)

In [None]:
torch.cuda.empty_cache()