## Imports

In [1]:
pip install datasets

Defaulting to user installation because normal site-packages is not writeable
Note: you may need to restart the kernel to use updated packages.


In [2]:
from datasets import load_dataset
dataset = load_dataset("code_search_net", "ruby")

In [3]:
!pip install faiss-cpu
!pip install -U sentence-transformers

Defaulting to user installation because normal site-packages is not writeable
Defaulting to user installation because normal site-packages is not writeable


In [4]:
import numpy as np
import torch
import os
import pandas as pd
import faiss
import time
from sentence_transformers import SentenceTransformer

## Semantic Search

In [5]:
documents = dataset['train']['func_documentation_string']

In [6]:
model = SentenceTransformer('BAAI/bge-base-en-v1.5')

In [7]:
model

SentenceTransformer(
  (0): Transformer({'max_seq_length': 512, 'do_lower_case': True}) with Transformer model: BertModel 
  (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': True, 'pooling_mode_mean_tokens': False, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})
  (2): Normalize()
)

In [8]:
encoded_data = model.encode(documents)

In [9]:
index = faiss.IndexIDMap(faiss.IndexFlatIP(model.get_sentence_embedding_dimension()))
index.add_with_ids(encoded_data, np.array(range(0, len(documents))))

In [10]:
#serializing index to export it across different host
faiss.write_index(index, 'sample_documents')

#de-serializing the index
index = faiss.read_index('sample_documents')

In [15]:
#version that is ordered by document id

def semantic_search(query):
    t = time.time()
    query_vector = model.encode([query])
    
    # Search for all results (remove the 'k' limit)
    top_k = index.search(query_vector, index.ntotal)
    
    # Extract document IDs and scores for all documents
    document_ids = top_k[1].tolist()[0]
    scores = top_k[0].tolist()[0]
    
    # Create a list of tuples containing document IDs and scores for all documents
    results = [(doc_id, score) for doc_id, score in zip(document_ids, scores)]
    
    # Sort the results by document ID
    results.sort(key=lambda x: x[0])
    
    # Return just the scores
    semantic_scores = [i[1] for i in results]

    # normalize semantic scores
    max_score = max(semantic_scores)
    normalized_semantic_scores = [score / max_score for score in semantic_scores]
    normalized_semantic_scores
    
    return normalized_semantic_scores

In [316]:
semantic_search("enumerable")

[0.5416221724600533,
 0.5807885261964408,
 0.4734582584903372,
 0.45171259417869786,
 0.46962370648423013,
 0.5173304057101725,
 0.48209753110720666,
 0.4597470821218634,
 0.47024407109098026,
 0.43402076307534077,
 0.43802731173158066,
 0.5497931028471692,
 0.4382009789572848,
 0.5873963062847779,
 0.5942987613771644,
 0.5891135853384477,
 0.6320330686802671,
 0.576444085472457,
 0.4790145928407847,
 0.5504531254653098,
 0.5775816748028539,
 0.5977843809900122,
 0.6840237773747413,
 0.5003240916608499,
 0.5268716801850126,
 0.4633748459257204,
 0.5543101939247703,
 0.5698768350001562,
 0.6557173269919253,
 0.7246017674689545,
 0.6601580800292859,
 0.7163817366867837,
 0.7268045303104147,
 0.7322093507199091,
 0.5010128772323066,
 0.5985741274329572,
 0.4990434139010772,
 0.40897796365553896,
 0.6618133298856528,
 0.6647651642865807,
 0.5991679807317266,
 0.6665512274735595,
 0.5853292958669234,
 0.5975440360088933,
 0.6827796343763778,
 0.5961996221593859,
 0.5658309188674079,
 0.6275

## BM-25 Search

In [17]:
func_tokens = dataset['train']['func_code_tokens']

In [18]:
pip install rank_bm25

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Defaulting to user installation because normal site-packages is not writeable
Collecting rank_bm25
  Downloading rank_bm25-0.2.2-py3-none-any.whl (8.6 kB)
Installing collected packages: rank-bm25
Successfully installed rank-bm25-0.2.2
Note: you may need to restart the kernel to use updated packages.


In [19]:
from rank_bm25 import BM25Okapi

In [20]:
bm25 = BM25Okapi(func_tokens)

In [21]:
def bm25_search(user_input):
    
    doc_scores = bm25.get_scores(user_input)
    max_score = max(doc_scores)

    # normalize BM25 scores
    normalized_doc_scores = [score / max_score for score in doc_scores]
    
    return normalized_doc_scores

In [352]:
bm25_search("max")

[0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.1499767387289621,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.26880636432044774,
 0.2664250820529291,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.25333664607045314,
 0.0,
 0.0,
 0.269611087254954,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.32909355032903864,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0

## Combined Search

In [None]:
# def search_results(user_input):
    
#     s = semantic_search(user_input)
#     t = tfidf_search(user_input)
#     overlap_results = list(set(s) & set(t))
    
#     top_10_docs = overlap_results
        
#     while len(top_10_docs) < 10:
#         for i in s:
#             if i not in top_10_docs:
#                 top_10_docs.append(i)
    
#     if len(top_10_docs) > 10:
#         top_10_docs = overlap_results[:10]
        
#     function_name = []
#     doc_string = []
#     for i in top_10_docs:
#         function_name.append(dataset['train']['func_name'][i])
#         doc_string.append(dataset['train']['func_documentation_string'][i])
        
#     results_df = pd.DataFrame({'Document': top_10_docs, 'Function': function_name, 'Documentation': doc_string})
    
#     return results_df

In [467]:
def find_quartiles(data):
    # Filter out zeros and sort the remaining data
    filtered_sorted_data = sorted([x for x in data if x != 0])

    n = len(filtered_sorted_data)
    if n == 0:
        # Handle the case where all values are zero
        return [0 for _ in data]

    # Calculate quartile breakpoints
    q1 = filtered_sorted_data[int(n * 0.25) - 1]
    q2 = filtered_sorted_data[int(n * 0.5) - 1]
    q3 = filtered_sorted_data[int(n * 0.75) - 1]

    # Assign quartiles including zeros
    quartiles = []
    for value in data:
        if value <= q1:
            quartiles.append(0)
        elif value <= q2:
            quartiles.append(1)
        elif value <= q3:
            quartiles.append(2)
        else:
            quartiles.append(3)

    return quartiles

In [541]:
def search_results(sem_weight, bm_weight, user_input):
    
    sem = semantic_search(user_input)
    bm = bm25_search(user_input)
    weighted_sem = [i * sem_weight for i in sem]
    weighted_bm = [i * bm_weight for i in bm]
    weighted_avg = [weighted_sem[i]+ weighted_bm[i] for i in range(0, len(weighted_bm))]
    sum_weight = sem_weight + bm_weight 
    weighted_avg_norm = [i/sum_weight for i in weighted_avg]
    url = dataset["train"]['func_code_url']
    if not weighted_avg_norm or np.isnan(weighted_avg_norm).any():
        # Handle the empty or invalid input case
        return {} 
    
    try:
#         import math
#         labels = pd.qcut(weighted_avg_norm, q=4, labels=False, duplicates='drop')
#         output_dict = {url[i]: labels[i] for i in range(len(weighted_avg_norm))}
        labels = find_quartiles(weighted_avg_norm)
        output_dict = {url[i]: labels[i] for i in range(len(weighted_avg_norm))}
        
    except ValueError:
        # Handle the case where qcut fails
        return {}
    
    return output_dict

In [542]:
search_results(1, 0, "func").values()

dict_values([2, 2, 0, 1, 0, 2, 1, 2, 1, 1, 1, 1, 0, 1, 2, 2, 2, 2, 0, 2, 3, 3, 3, 1, 0, 2, 2, 0, 3, 3, 2, 2, 2, 3, 2, 0, 1, 0, 1, 3, 3, 3, 1, 3, 3, 0, 1, 3, 3, 1, 2, 0, 3, 1, 2, 1, 3, 2, 3, 3, 2, 2, 3, 0, 1, 0, 0, 3, 0, 3, 3, 2, 0, 2, 3, 1, 1, 2, 0, 2, 1, 0, 2, 0, 2, 1, 3, 1, 3, 1, 1, 3, 3, 2, 3, 2, 2, 1, 3, 3, 3, 1, 0, 3, 3, 3, 2, 3, 3, 1, 3, 3, 2, 3, 3, 3, 2, 1, 2, 3, 3, 1, 1, 2, 1, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 2, 3, 0, 1, 0, 3, 3, 3, 1, 3, 2, 2, 3, 2, 0, 1, 3, 0, 0, 3, 2, 3, 3, 1, 3, 3, 1, 2, 2, 0, 0, 0, 0, 0, 0, 0, 1, 0, 3, 1, 2, 3, 1, 1, 2, 2, 1, 2, 2, 2, 2, 3, 1, 1, 1, 0, 0, 3, 2, 0, 2, 2, 1, 1, 1, 2, 1, 0, 2, 0, 0, 0, 1, 3, 3, 2, 3, 2, 0, 0, 2, 2, 3, 2, 3, 3, 3, 3, 3, 3, 3, 3, 2, 2, 1, 2, 3, 1, 3, 2, 3, 1, 3, 3, 3, 0, 1, 1, 1, 0, 3, 3, 3, 2, 2, 2, 2, 3, 3, 3, 3, 2, 1, 3, 1, 1, 1, 2, 3, 3, 3, 3, 2, 0, 3, 2, 2, 0, 3, 3, 2, 2, 1, 2, 3, 3, 0, 1, 3, 3, 3, 0, 2, 3, 3, 3, 3, 3, 2, 1, 3, 1, 2, 3, 3, 2, 1, 0, 2, 0, 3, 1, 3, 3, 3, 0, 3, 0, 2, 2, 0, 1, 0, 1, 1, 1, 3, 3, 1, 1, 1, 1, 

In [488]:
search_results(0, 1, "func").values()

dict_values([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 

In [543]:
def precision(search_output, eval_dict):
    
    inter = list(search_output.keys() & eval_dict.keys())
    predicted_scores = [search_output[i] for i in search_output.keys() if i in inter]
    true_scores = [eval_dict[i] for i in eval_dict.keys() if i in inter]
    
    
    def is_relevant(score):
        return score >= 2  

    sorted_items = sorted(zip(predicted_scores, true_scores), reverse=True, key=lambda x: x[0])

    top_k_items = sorted_items

    relevant_count = sum(1 for _, true_score in top_k_items if is_relevant(true_score))
    return relevant_count/len(top_k_items)

In [549]:
def precision(search_output, eval_dict):
    
    inter = list(search_output.keys() & eval_dict.keys())
    predicted_scores = [search_output[i] for i in search_output.keys() if i in inter]
    true_scores = [eval_dict[i] for i in eval_dict.keys() if i in inter]
    
    
    def TP_def(pscore, tscore):
        if pscore >= 2:
            if tscore >= 2:
                return True
        else:
            return False
        
    def TN_def(pscore, tscore):
        if pscore < 2:
            if tscore < 2:
                return True
        else:
            return False
    
    
    sorted_items = sorted(zip(predicted_scores, true_scores), reverse=True, key=lambda x: x[0])
    
    TP = sum(1 for tp, ts in sorted_items if TP_def(tp,ts))
    TN = sum(1 for tp, ts in sorted_items if TN_def(tp,ts))

    return (TP+TN)/len(predicted_scores)

In [550]:
eval_df = pd.read_csv("annotationStore.csv") 
eval_ruby = eval_df[eval_df["Language"] == "Ruby"]

In [551]:
# def find_best_weights(sem_weight, bm_weight):
#     prec = []
#     for i in list(eval_ruby["Query"].unique()):
#         subset = eval_ruby[eval_ruby["Query"] == i]
#         if subset.shape[0] == 0:
#             continue
#         evals = pd.Series(subset.Relevance.values,index=subset.GitHubUrl).to_dict()
#         our_search = search_results(sem_weight, bm_weight, i)
#         inter = list(evals.keys() & our_search.keys())
#         if len(inter)==0:
#             continue
#         prec.append(precision(our_search, evals))
#         print(len(prec))
#     return sum(prec)/len(prec)

In [552]:
def find_best_weights(sem_weight, bm_weight):

    grouped_evals = eval_ruby.groupby('Query').apply(lambda x: pd.Series(x.Relevance.values, index=x.GitHubUrl).to_dict())

    prec = []
    search_results_cache = {}  

    for query, evals in grouped_evals.items():
        if (sem_weight, bm_weight, query) not in search_results_cache:
            search_results_cache[(sem_weight, bm_weight, query)] = search_results(sem_weight, bm_weight, query)
        
        our_search = search_results_cache[(sem_weight, bm_weight, query)]
        inter = set(evals.keys()) & set(our_search.keys())

        if inter:
            a = precision(our_search, evals)
            prec.append(a)
    return sum(prec) / len(prec) if prec else 0

In [None]:
find_best_weights(sem_weight, bm_weight)

In [562]:
import pandas as pd
import numpy as np

def grid_search(sem_weight_range, bm_weight_range, increment):
    best_precision = 0
    best_weights = (0, 0)

    for sem_weight in np.arange(*sem_weight_range, increment):
        bm_weight = 5 - sem_weight
        current_precision = find_best_weights(sem_weight, bm_weight)
        if current_precision > best_precision:
            best_precision = current_precision
            best_weights = (sem_weight, bm_weight)
        print(sem_weight, bm_weight)
        print(current_precision)

    return best_weights, best_precision

# Example usage
sem_weight_range = (0, 6)  # Define the range for sem_weight 
increment = 1           # Define the increment

best_weights, best_precision = grid_search(sem_weight_range, bm_weight_range, increment)

0 5
0.6674242424242425
1 4
0.27323232323232327
2 3
0.27323232323232327
3 2
0.27323232323232327
4 1
0.2681818181818182
5 0
0.2681818181818182


In [563]:
best_weights

(0, 5)

In [558]:
best_precision

0.6674242424242425

# End to End

In [565]:
def end_to_end():
    grouped_evals = eval_ruby.groupby('Query').apply(lambda x: pd.Series(x.Relevance.values, index=x.GitHubUrl).to_dict())

    prec = []
    search_results_cache = {}  

    sem_weight, bm_weight = best_weights
    
    for query, evals in grouped_evals.items():
        if (sem_weight, bm_weight, query) not in search_results_cache:
            search_results_cache[(sem_weight, bm_weight, query)] = search_results(sem_weight, bm_weight, query)
        
        our_search = search_results_cache[(sem_weight, bm_weight, query)]
        inter = set(evals.keys()) & set(our_search.keys())

        if inter:
            prec.append(precision(our_search, evals))
    return sum(prec) / len(prec) if prec else 0

In [566]:
end_to_end()

0.6674242424242425

## Search Results

In [21]:
search_results("enumerable")

Unnamed: 0,Document,Function,Documentation
0,34424,Pandata.DataFormatter.custom_sort,Sorts alphabetically ignoring the initial 'The...
1,3890,Twitter.Utils.flat_pmap,Returns a new array with the concatenated resu...
2,3891,Twitter.Utils.pmap,Returns a new array with the results of runnin...
3,38981,TeradataExtractor.Query.enumerable,"returns an enumerable, each element of which i..."
4,43259,Doublylinkedlist.Doublylinkedlist.each,Método para que la lista sea enumerable
5,42390,Yargi.ElementSet.grep,See Enumerable.grep
6,42720,MMETools.Enumerable.classify,Interessant iterador que classifica un enumera...
7,30079,Wbem.WsmanClient.each_instance,Enumerate instances
8,46812,StixSchemaSpy.SimpleType.enumeration_values,Returns the list of values for this enumeration
9,5229,Magick.ImageList.reject,override Enumerable's reject
