# Customizing codebase embeddings with a projection matrix

## Goal

The goal of this exercise is to learn a projection matrix that tailors embeddings to a codebase retrieval use case, and then to measure the improvement in retrieval quality.

The notebook is mostly filled out, but has a series of small gaps that you will need to fill in (everywhere you see a "TODO" comment):
- Define the similarity functions (both basic and with projection matrix)
- Define a suitable loss function
- Construct examples for training from the pre-existing dataset
- Complete the training loop code
- Finish the retrieval function logic
- Evaluate the improvement in retrieval quality

## Background

A basic retrieval-augmented generation (RAG) system typically uses embeddings to represent the set of documents to be searched. The user input is converted to an embedding as well, and the system uses the dot product of the two embeddings to score how relevant each document in the database is to the input.

Many embedding models are "symmetric", which means that they treat user input text and documents (e.g. code snippets) in the same way. It might be preferable to calculate the embedding differently ("asymmetrically") for the user input because it is a fundamentally different type of text.

One way of doing this is to use the same embedding model, and then apply a matrix multiplication to the embedding of the user input. What we'll try to do here is find such a matrix that can improve retrieval quality.
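
As a minimal sketch of this idea, assuming random stand-in vectors rather than real model output (the names `query_emb`, `doc_emb`, and `W` are illustrative and do not appear in the notebook): multiplying the query embedding by the identity matrix reproduces the plain dot product, and any other matrix changes only how the query side is represented.

```python
import torch

torch.manual_seed(0)
dim = 4  # toy dimension, chosen only for illustration

query_emb = torch.randn(dim)  # stand-in for the embedded user input
doc_emb = torch.randn(dim)    # stand-in for an embedded code snippet

# Symmetric similarity: both texts are treated the same way
symmetric_score = torch.dot(query_emb, doc_emb)

# Asymmetric similarity: only the query embedding is transformed
W = torch.eye(dim)  # identity matrix, so nothing changes yet
asymmetric_score = torch.dot(W @ query_emb, doc_emb)

# With the identity matrix the two scores coincide
print(torch.isclose(symmetric_score, asymmetric_score))
```

The identity is therefore a natural reference point: a learned matrix is only useful if it ranks relevant snippets better than this baseline does.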

## Environment

We recommend using a virtual environment to install the necessary packages.

```bash
python -m venv env
source env/bin/activate
```

### Install packages

```bash
pip install -r requirements.txt
```

## Setup

Here we generate a sample embedding with `sentence_transformers`:

In [75]:
import numpy as np
import torch
from sentence_transformers import SentenceTransformer

# Load a pre-trained model (this will be slow the first time)
model = SentenceTransformer("all-MiniLM-L12-v2")

def embed(text):
    embedding = model.encode([text])[0]
    return torch.tensor(embedding, dtype=torch.float32)

embedding = embed("Hello world")
dim = len(embedding)

print(f"Embedding dimension: {dim}")
print(f"Embedding: [{embedding[0]}, {embedding[1]}, ..., {embedding[-1]}]")


Embedding dimension: 384
Embedding: [-0.07597316801548004, -0.005261966027319431, ..., 0.03495463728904724]


## Similarity

First, we'll define similarity. It can be calculated as the dot product of two embeddings. For example, if we were trying to find the similarity between a user input $x_i$ and a code snippet $x_c$, then the similarity would be

$$h(x_i, x_c) = e(x_i) \cdot e(x_c)$$

Fill out the function below:

In [76]:
# Define the similarity function using torch
def similarity(x_i, x_c):
    embedded_x_i = model.encode(x_i)
    embedded_x_i = torch.tensor(embedded_x_i, dtype=torch.float32)
    embedded_x_c = model.encode(x_c)
    embedded_x_c = torch.tensor(embedded_x_c, dtype=torch.float32)

    similarity = torch.dot(embedded_x_i, embedded_x_c)
    return similarity

# Calculate the similarity between two strings
x_i = "Where in the codebase do we do auth?"
x_c_1 = "```python\n# Authentication\ndef authenticate(username, password):\n    # Code to authenticate the user\n```"
x_c_2 = "function sum(a, b) {\n    return a + b;\n}"

similarity1 = similarity(x_i, x_c_1)
similarity2 = similarity(x_i, x_c_2)
print(f"Similarity 1: {similarity1}")
print(f"Similarity 2: {similarity2}")

Similarity 1: 0.4509845972061157
Similarity 2: 0.03275971859693527


## Similarity with projection matrix

Next, we'll calculate similarity using the projection matrix $\theta$ (called `P` in the code), which is applied to the embedding of the user input:

$$h_\theta(x_i, x_c) = e(x_c)^\top \, \theta \, e(x_i)$$
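
Written out in components, and reading $e(x_c)$ as a row vector and $e(x_i)$ as a column vector of dimension $d$, the projected similarity is a weighted sum over all pairs of embedding dimensions. Setting $\theta$ to the identity matrix $I$ recovers the plain dot product from the previous section:

$$h_\theta(x_i, x_c) = \sum_{j=1}^{d} \sum_{k=1}^{d} e(x_c)_j \, \theta_{jk} \, e(x_i)_k, \qquad h_I(x_i, x_c) = \sum_{j=1}^{d} e(x_c)_j \, e(x_i)_j = e(x_i) \cdot e(x_c)$$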

Fill in the function below:

In [77]:
def similarity_with_projection(x_i, x_c, P):
    embedded_x_i = model.encode(x_i)
    embedded_x_c = model.encode(x_c)

    embedded_x_i = torch.tensor(embedded_x_i, dtype=torch.float32)
    embedded_x_c = torch.tensor(embedded_x_c, dtype=torch.float32).squeeze()

    projected = torch.matmul(P, embedded_x_i)
    similarity = torch.dot(projected, embedded_x_c)
    return similarity

# Generate a dim by dim random matrix
P_random = torch.randn(dim, dim, dtype=torch.float32)
print(P_random)

# Calculate the similarity with the random projection matrix
similarity_with_projection1 = similarity_with_projection(x_i, x_c_1, P_random)
similarity_with_projection2 = similarity_with_projection(x_i, x_c_2, P_random)
print(f"Similarity with projection 1: {similarity_with_projection1}")
print(f"Similarity with projection 2: {similarity_with_projection2}")

tensor([[-0.7340,  0.7239, -0.3426,  ..., -0.7377, -0.2822, -0.8234],
        [ 0.2058, -1.9054, -0.5074,  ...,  0.1905,  0.5809,  1.4301],
        [ 0.2523, -0.1437, -0.4516,  ..., -0.3498, -2.2366, -1.4045],
        ...,
        [ 0.8269, -0.6756, -1.3165,  ..., -0.0425,  0.5219,  1.0697],
        [ 0.7656,  1.3402, -1.3406,  ...,  0.9046,  0.8978, -1.0893],
        [-0.7491, -0.7929, -0.2276,  ...,  0.6324,  0.1920,  1.0543]])
Similarity with projection 1: 0.13125662505626678
Similarity with projection 2: 1.7634925842285156


## Load dataset

To train and test a matrix that is more helpful than the random one above, we will use a pre-existing dataset of (question, relevant code snippets) pairs. The pairs were generated by a language model.

In [78]:
# Load the dataset from XML file (dataset.xml)

import xml.etree.ElementTree as ET
from dataclasses import dataclass
from typing import List

@dataclass 
class Example:
    user_input: str
    snippets: List[str]

class DatasetParser:
    def __init__(self, xml_file: str):
        self.tree = ET.parse(xml_file)
        self.root = self.tree.getroot()

    def parse(self) -> List[Example]:
        examples = []
        
        for example in self.root.findall('example'):
            user_input = example.find('user_input').text
            snippets_list = []
            
            for snippet in example.find('snippets').findall('snippet'):
                # Extract code and filename from the snippet text
                snippet_text = snippet.text.strip()
                
                # Parse the filename from the code block header
                first_line = snippet_text.split('\n')[0]
                filename = first_line.split(' ')[1] if len(first_line.split(' ')) > 1 else None
                
                # Remove the code block markers and get just the code
                code_lines = snippet_text.split('\n')[1:-1]
                code = '\n'.join(code_lines)
                
                snippets_list.append(code)
                
            examples.append(Example(
                user_input=user_input,
                snippets=snippets_list
            ))
            
        return examples


parser = DatasetParser('dataset.xml')
dataset = parser.parse()


In [79]:
print(dataset)

[Example(user_input='How do we handle password reset flows?', snippets=['def initiate_password_reset(email):\n    token = generate_reset_token()\n    send_reset_email(email, token)\n    store_reset_token(email, token, expiry=24*hours)\n    return True\n\ndef validate_reset_token(token, new_password):\n    if is_token_valid(token):\n        user = get_user_by_token(token)\n        update_password(user, new_password)\n        invalidate_token(token)\n        return True\n    return False']), Example(user_input='Where is the user registration logic implemented?', snippets=['export class UserService {\n  async register(userData: RegisterDTO): Promise<User> {\n    const existingUser = await this.userRepo.findByEmail(userData.email);\n    if (existingUser) {\n      throw new DuplicateUserError();\n    }\n    \n    const hashedPassword = await bcrypt.hash(userData.password);\n    return this.userRepo.create({\n      ...userData,\n      password: hashedPassword\n    });\n  }\n}']), Example(use

## Construct examples

Convert the dataset into a set of examples that can be used to train the projection matrix. These should include both examples of input/snippet pairs where the snippet is relevant, and pairs where the snippet is not relevant.
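
One possible way to build such pairs is sketched here on toy data. `ToyExample` and `build_pairs` are hypothetical names used only for illustration, and the sketch assumes that a snippet listed under a different question is irrelevant to the current one.

```python
import random
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class ToyExample:  # hypothetical stand-in for the notebook's Example class
    user_input: str
    snippets: List[str]


def build_pairs(examples: List[ToyExample], seed: int = 0) -> List[Tuple[str, str, int]]:
    """Return (user_input, snippet, label) triples with one negative per positive."""
    rng = random.Random(seed)
    pairs = []
    for i, example in enumerate(examples):
        # Candidate negatives: snippets that belong to any other example
        other_snippets = [s for j, ex in enumerate(examples) if j != i for s in ex.snippets]
        for snippet in example.snippets:
            pairs.append((example.user_input, snippet, 1))
            pairs.append((example.user_input, rng.choice(other_snippets), 0))
    return pairs


toy = [
    ToyExample("Where do we hash passwords?", ["def hash_password(pw): ..."]),
    ToyExample("How is the DB configured?", ["DB_HOST = 'localhost'", "POOL_SIZE = 10"]),
]
pairs = build_pairs(toy)
print(len(pairs))  # 6: three positives and three negatives
```

Each pair holds a single snippet string rather than a list, which keeps the embedding shape the same for positives and negatives. If two questions can share a relevant snippet, the sampled negatives may include false negatives, so that assumption is worth checking against the real dataset.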

In [80]:
# Next, you should generate a list of positive and negative pairs from the dataset
# These will be used to train the matrix
import random

# TODO: Create example pairs from the dataset
example_pairs = []  # list of tuples (user input, code snippet, 1 if snippet is relevant to user input else 0)

# creates equal amount of 1/0 pairs
for example in dataset:
    example_pairs.append((example.user_input, example.snippets, 1))
    # irrelevant
    others = [ex for ex in dataset if ex != example]
    rexample = random.choice(others)
    example_pairs.append((example.user_input, random.choice(rexample.snippets), 0))


In [81]:
print(example_pairs)

[('How do we handle password reset flows?', ['def initiate_password_reset(email):\n    token = generate_reset_token()\n    send_reset_email(email, token)\n    store_reset_token(email, token, expiry=24*hours)\n    return True\n\ndef validate_reset_token(token, new_password):\n    if is_token_valid(token):\n        user = get_user_by_token(token)\n        update_password(user, new_password)\n        invalidate_token(token)\n        return True\n    return False'], 1), ('How do we handle password reset flows?', "module.exports = {\n  development: {\n    client: 'postgresql',\n    connection: {\n      host: process.env.DB_HOST,\n      database: process.env.DB_NAME,\n      user: process.env.DB_USER,\n      password: process.env.DB_PASSWORD\n    },\n    pool: {\n      min: 2,\n      max: 10\n    }\n  }\n}", 0), ('Where is the user registration logic implemented?', ['export class UserService {\n  async register(userData: RegisterDTO): Promise<User> {\n    const existingUser = await this.userR

In [82]:
# Here we split the example pairs into training and validation sets
np.random.shuffle(example_pairs)
split_index = int(0.8 * len(example_pairs))
train_pairs = example_pairs[:split_index]
val_pairs = example_pairs[split_index:]

print(f"Number of training pairs: {len(train_pairs)}")
print(f"Number of validation pairs: {len(val_pairs)}")

Number of training pairs: 32
Number of validation pairs: 8


## Define a loss function

With a model to calculate similarity, and a dataset of positive and negative examples, we're almost ready to train. The last thing we need is a loss function. Design a loss function that is suitable for this use case.
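
Since the similarity scores are unbounded dot products, one reasonable option is to treat them as logits and use binary cross-entropy. The sketch below uses made-up scores and labels to show what `torch.nn.functional.binary_cross_entropy_with_logits` computes, and checks it against the formula written out by hand. Other choices, such as a margin-based ranking loss, would also be defensible.

```python
import math

import torch
import torch.nn.functional as F

# Raw similarity scores (logits) for three hypothetical pairs, and their 0/1 relevance labels
scores = torch.tensor([2.0, -1.0, 0.0])
labels = torch.tensor([1.0, 0.0, 1.0])

# The functional form applies the sigmoid internally and averages over the pairs
loss = F.binary_cross_entropy_with_logits(scores, labels)

# The same quantity written out by hand, for comparison
probs = torch.sigmoid(scores)
manual = -(labels * torch.log(probs) + (1 - labels) * torch.log(1 - probs)).mean()

print(loss.item(), manual.item())

# A score of exactly 0 corresponds to probability 0.5, which costs ln(2) whatever the label
zero_loss = F.binary_cross_entropy_with_logits(torch.tensor([0.0]), torch.tensor([1.0]))
print(math.isclose(zero_loss.item(), math.log(2), rel_tol=1e-6))
```

Working on logits avoids taking the log of a probability that has rounded to 0 or 1, which is the main practical reason to prefer this form over applying a sigmoid and a plain cross-entropy separately.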

In [83]:
import torch.nn as nn

# Binary cross-entropy on the raw similarity scores (logits):
# each pair is labelled relevant (1) or not relevant (0)
def loss_func(predictions, targets):
    # Use the functional form here. nn.BCEWithLogitsLoss is a module that must be
    # instantiated first, so nn.BCEWithLogitsLoss(predictions, targets) does not compute a loss.
    loss = nn.functional.binary_cross_entropy_with_logits(predictions, targets)
    return loss

## Train the projection matrix

The entire training loop has been set up, except for a couple of lines to calculate the prediction given an example pair and to get $y$, which will then be used together to calculate the loss.
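
For reference, the sketch below shows the same kind of loop in isolation, on random stand-in embeddings so that it runs without the model or the dataset. The toy sizes, the identity initialization, and the learning rate of 0.01 are assumptions made for this illustration and differ from the settings in the cell that follows.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n_pairs, dim = 16, 8  # toy sizes, chosen only for illustration

# Random stand-ins for precomputed embeddings; in the notebook these would come from model.encode
query_embs = torch.randn(n_pairs, dim)
snippet_embs = torch.randn(n_pairs, dim)
labels = torch.tensor([1.0, 0.0] * (n_pairs // 2))

# Start from the identity so that training begins at the unprojected similarity
P = torch.eye(dim, requires_grad=True)
optimizer = torch.optim.Adam([P], lr=0.01)

losses = []
for step in range(100):
    optimizer.zero_grad()
    # Row-wise e(x_c) . (P e(x_i)) for every pair at once
    scores = ((query_embs @ P.T) * snippet_embs).sum(dim=1)
    loss = F.binary_cross_entropy_with_logits(scores, labels)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(f"first loss: {losses[0]:.4f}, last loss: {losses[-1]:.4f}")
```

Embedding every text once up front and batching the score computation means the encoder is not called inside the optimization loop. Starting from the identity rather than a random matrix means the first epoch is evaluated at the baseline similarity.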

In [89]:
import torch.optim as optim

# Loss: binary cross-entropy applied to the raw similarity score (a logit)
criterion = nn.BCEWithLogitsLoss()

# Initialize the projection matrix P
P = torch.randn(
    dim, dim, requires_grad=True
)

# Set hyperparameters
lr = 0.1
num_epochs = 25
optimizer = optim.Adam([P], lr=lr)
epochs, types, losses, accuracies, matrices = [], [], [], [], []

for epoch in range(num_epochs):
    # Reset gradients
    optimizer.zero_grad()

    # Iterate through training pairs
    for pair in train_pairs:
        prediction = similarity_with_projection(pair[0], pair[1], P)
        prediction = prediction.unsqueeze(0)
        y = torch.tensor(pair[2], dtype=torch.float32).unsqueeze(0)

        loss = criterion(prediction, y)
        loss.backward()
    
    # Update weights using Adam optimizer
    optimizer.step()

    # Calculate validation loss; no gradients are needed here
    val_loss = 0
    with torch.no_grad():
        for pair in val_pairs:
            prediction = similarity_with_projection(pair[0], pair[1], P)
            prediction = prediction.unsqueeze(0)
            y = torch.tensor(pair[2], dtype=torch.float32).unsqueeze(0)

            val_loss += criterion(prediction, y)

    print(f"Epoch {epoch}/{num_epochs}: validation loss: {val_loss.item() / len(val_pairs)}")

Epoch 0/25: validation loss: 1.4176667928695679
Epoch 1/25: validation loss: 1.2221359014511108
Epoch 2/25: validation loss: 1.4202094078063965
Epoch 3/25: validation loss: 1.6523520946502686
Epoch 4/25: validation loss: 1.8723056316375732
Epoch 5/25: validation loss: 2.0713889598846436
Epoch 6/25: validation loss: 2.2501027584075928
Epoch 7/25: validation loss: 2.4108850955963135
Epoch 8/25: validation loss: 2.556072950363159
Epoch 9/25: validation loss: 2.6875805854797363
Epoch 10/25: validation loss: 2.8069610595703125
Epoch 11/25: validation loss: 2.915496587753296
Epoch 12/25: validation loss: 3.0142710208892822
Epoch 13/25: validation loss: 3.104220151901245
Epoch 14/25: validation loss: 3.186164379119873
Epoch 15/25: validation loss: 3.260831594467163
Epoch 16/25: validation loss: 3.328873872756958
Epoch 17/25: validation loss: 3.390878438949585
Epoch 18/25: validation loss: 3.4473750591278076
Epoch 19/25: validation loss: 3.4988455772399902
Epoch 20/25: validation loss: 3.54572

## Retrieval strategy

We now have a potentially improved embedding model, but need to use it for retrieval. Finish the retrieval function, which will take a user input and return relevant code snippets from the full list. Note: a vector database is not necessary.
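
A minimal sketch of retrieval that uses a projection matrix, assuming the snippet embeddings have already been computed and stacked into one tensor (`retrieve_top_k` and its parameters are hypothetical names, and the toy vectors stand in for real embeddings):

```python
import torch


def retrieve_top_k(query_emb: torch.Tensor, snippet_embs: torch.Tensor, P: torch.Tensor, k: int = 2):
    """Return (indices, scores) of the k snippets with the highest projected similarity."""
    projected_query = P @ query_emb            # shape: (dim,)
    scores = snippet_embs @ projected_query    # shape: (n_snippets,)
    top = torch.topk(scores, k=min(k, scores.numel()))
    return top.indices.tolist(), top.values.tolist()


# Toy data: four orthogonal "snippet embeddings" and a query identical to snippet 2
snippet_embs = torch.eye(4)
query_emb = snippet_embs[2].clone()

indices, scores = retrieve_top_k(query_emb, snippet_embs, P=torch.eye(4), k=2)
print(indices[0], scores[0])
```

Passing the identity matrix for `P` gives the unprojected baseline, so the same function can serve both strategies when they are compared. Computing the snippet embeddings once, outside the function, avoids re-encoding the whole corpus for every query.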

In [94]:
all_snippets = []

for example in dataset:
    for snippet in example.snippets:
        all_snippets.append(snippet)

# Use similarity search with the embeddings model to retrieve relevant snippets
def retrieve_relevant_snippets(user_input: str):
    ui_encode = model.encode(user_input)
    ui_embed = torch.tensor(ui_encode, dtype=torch.float32)

    snip_encode = model.encode(all_snippets)
    snip_embed = torch.tensor(snip_encode, dtype=torch.float32)

    cos_sim = nn.functional.cosine_similarity(ui_embed, snip_embed)

    top_k = torch.topk(cos_sim, k=5).indices

    relevant_snippets = [(all_snippets[idx], cos_sim[idx].item()) for idx in top_k]

    return relevant_snippets

    
retrieve_relevant_snippets("How do we handle API errors?")

[("export class ErrorHandler {\n  catch(error: Error, req: Request, res: Response, next: NextFunction) {\n    if (error instanceof ValidationError) {\n      return res.status(400).json({\n        type: 'ValidationError',\n        message: error.message\n      });\n    }\n    \n    if (error instanceof AuthenticationError) {\n      return res.status(401).json({\n        type: 'AuthenticationError',\n        message: 'Unauthorized'\n      });\n    }\n    \n    // Default error\n    res.status(500).json({\n      type: 'ServerError',\n      message: 'Internal server error'\n    });\n  }\n}",
  0.47227051854133606),
 ('export class ValidationMiddleware {\n  validate(schema: Joi.Schema) {\n    return (req: Request, res: Response, next: NextFunction) => {\n      const { error } = schema.validate(req.body);\n      \n      if (error) {\n        return res.status(400).json({\n          error: error.details[0].message\n        });\n      }\n      \n      next();\n    };\n  }\n}',
  0.377465367317

## Evaluate the new retrieval strategy

If the loss was lower by the last epoch, then we know that we improved the similarity function (at least for the validation set), but we still need a way of evaluating the retrieval strategy as a whole.

Your last task is to design an evaluation metric suitable for codebase retrieval, which we can run over the examples in the above dataset. The result of the evaluation should be a single number that attempts to represent the quality of the retrieval strategy.
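
Two common candidates are mean reciprocal rank (MRR) and recall at k. The sketch below computes both on hand-made rankings. The function names and the toy data are illustrative only, and which metric suits this dataset best remains a judgment call.

```python
from typing import List, Sequence, Set


def mean_reciprocal_rank(rankings: Sequence[List[str]], relevant: Sequence[Set[str]]) -> float:
    """Average of 1 / (1-based rank of the first relevant result); 0 when none is retrieved."""
    total = 0.0
    for ranked, gold in zip(rankings, relevant):
        for rank, item in enumerate(ranked, start=1):
            if item in gold:
                total += 1.0 / rank
                break
    return total / len(rankings)


def recall_at_k(rankings: Sequence[List[str]], relevant: Sequence[Set[str]], k: int) -> float:
    """Average fraction of each query's relevant items that appear in its top k results."""
    total = 0.0
    for ranked, gold in zip(rankings, relevant):
        total += len(gold.intersection(ranked[:k])) / len(gold)
    return total / len(rankings)


# Two hypothetical queries: the first finds its snippet at rank 2, the second at rank 1
rankings = [["a", "b", "c"], ["x", "y", "z"]]
relevant = [{"b"}, {"x"}]

print(mean_reciprocal_rank(rankings, relevant))  # 0.75
print(recall_at_k(rankings, relevant, k=1))      # 0.5
```

MRR rewards placing the first relevant snippet near the top, while recall at k only asks whether relevant snippets appear anywhere in the top k. Either one yields a single number that can be compared across retrieval strategies.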

In [None]:
def evaluate_retrieval_strategy(retrieval_strategy):
    total_rank = 0
    total_k = 0

    for example in dataset:
        user_input = example.user_input
        current_snippet = example.snippets
        relevant_snippets = retrieval_strategy(user_input)

        reciprocal_rank = 0
        # Ranks are 1-based, so a match at the top position contributes 1 / 1
        for rank, (snippet, _) in enumerate(relevant_snippets, start=1):
            print(snippet)
            print(current_snippet)
            if snippet in current_snippet:
                reciprocal_rank = 1 / rank
                break
        total_rank += reciprocal_rank

        # Each retrieved item is a (snippet, score) tuple, so unpack it before comparing
        in_top_k = sum(1 for snippet, _ in relevant_snippets if snippet in current_snippet)
        kscore = in_top_k / 5
        total_k += kscore

    mrr = total_rank / len(dataset)
    precision = total_k / len(dataset)
    return mrr, precision

result = evaluate_retrieval_strategy(retrieve_relevant_snippets)
print(result)

('def initiate_password_reset(email):\n    token = generate_reset_token()\n    send_reset_email(email, token)\n    store_reset_token(email, token, expiry=24*hours)\n    return True\n\ndef validate_reset_token(token, new_password):\n    if is_token_valid(token):\n        user = get_user_by_token(token)\n        update_password(user, new_password)\n        invalidate_token(token)\n        return True\n    return False', 0.5059391260147095)
['def initiate_password_reset(email):\n    token = generate_reset_token()\n    send_reset_email(email, token)\n    store_reset_token(email, token, expiry=24*hours)\n    return True\n\ndef validate_reset_token(token, new_password):\n    if is_token_valid(token):\n        user = get_user_by_token(token)\n        update_password(user, new_password)\n        invalidate_token(token)\n        return True\n    return False']
('export class UserService {\n  async register(userData: RegisterDTO): Promise<User> {\n    const existingUser = await this.userRepo.fin