In [1]:
pip install datasets

Defaulting to user installation because normal site-packages is not writeable
Note: you may need to restart the kernel to use updated packages.


In [2]:
from datasets import load_dataset
dataset = load_dataset("code_search_net", "ruby")

In [3]:
!pip install faiss-cpu
!pip install -U sentence-transformers

Defaulting to user installation because normal site-packages is not writeable
Defaulting to user installation because normal site-packages is not writeable


In [4]:
import numpy as np
import torch
import os
import pandas as pd
import faiss
import time
from sentence_transformers import SentenceTransformer

In [5]:
documents = dataset['train']['func_documentation_string']
documents

['Expose one or more attributes within a block. Old values are returned after the block concludes.\n Example demonstrating the common use of needing to set Current attributes outside the request-cycle:\n\n   class Chat::PublicationJob < ApplicationJob\n     def perform(attributes, room_number, creator)\n       Current.set(person: creator) do\n         Chat::Publisher.publish(attributes: attributes, room_number: room_number)\n       end\n     end\n   end',
 "Accepts a custom Rack environment to render templates in.\n It will be merged with the default Rack environment defined by\n +ActionController::Renderer::DEFAULTS+.\n Render templates with any options from ActionController::Base#render_to_string.\n\n The primary options are:\n * <tt>:partial</tt> - See <tt>ActionView::PartialRenderer</tt> for details.\n * <tt>:file</tt> - Renders an explicit template file. Add <tt>:locals</tt> to pass in, if so desired.\n   It shouldn’t be used directly with unsanitized user input due to lack of val

In [6]:
model = SentenceTransformer('BAAI/bge-base-en-v1.5')

In [7]:
model

SentenceTransformer(
  (0): Transformer({'max_seq_length': 512, 'do_lower_case': True}) with Transformer model: BertModel 
  (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': True, 'pooling_mode_mean_tokens': False, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})
  (2): Normalize()
)

In [8]:
encoded_data = model.encode(documents)
encoded_data

array([[ 0.01868366,  0.03399614,  0.01133468, ..., -0.03731885,
        -0.03780966,  0.00768813],
       [ 0.00994026,  0.00419443, -0.03402036, ..., -0.01948973,
         0.01017335,  0.02251645],
       [ 0.03411267, -0.0559819 , -0.00905228, ...,  0.00458416,
         0.02205012, -0.00351053],
       ...,
       [-0.01850404, -0.00849332, -0.00110278, ...,  0.0035976 ,
        -0.01985007,  0.02744106],
       [-0.04316016, -0.00474268, -0.02455162, ...,  0.02463778,
         0.00985802, -0.00188654],
       [-0.05866373, -0.01208141, -0.00431432, ...,  0.00159845,
         0.02519717, -0.02934378]], dtype=float32)

In [9]:
index = faiss.IndexIDMap(faiss.IndexFlatIP(model.get_sentence_embedding_dimension()))
index.add_with_ids(encoded_data, np.array(range(0, len(documents))))

In [10]:
#serializing index to export it across different host
faiss.write_index(index, 'sample_documents')

#de-serializing the index
index = faiss.read_index('sample_documents')

In [11]:
def semantic_search(query):
    t = time.time()
    query_vector = model.encode([query])
    # Search for top k results
    k = 100
    top_k = index.search(query_vector, k)

    return top_k[1].tolist()[0]
#[documents[_id] for _id in top_k[1].tolist()[0]])
    

In [12]:
query = "enumerable"
s_results = semantic_search(query)
s_results

[43259,
 42390,
 42720,
 30079,
 46812,
 38981,
 5229,
 27668,
 35569,
 30080,
 31973,
 14348,
 42645,
 43197,
 5569,
 14369,
 309,
 39285,
 30095,
 12269,
 39284,
 34339,
 5162,
 23291,
 39759,
 39732,
 3891,
 3890,
 44913,
 31120,
 42646,
 40730,
 41791,
 18359,
 25530,
 38330,
 31976,
 46034,
 35138,
 36845,
 22540,
 34234,
 32927,
 40294,
 30096,
 43084,
 23283,
 12518,
 38973,
 44726,
 33079,
 38225,
 38224,
 26087,
 36996,
 29467,
 34232,
 2886,
 45873,
 38982,
 43054,
 5161,
 38223,
 13462,
 24996,
 30694,
 29469,
 40017,
 46429,
 43055,
 34283,
 17066,
 30944,
 47124,
 46505,
 34235,
 46897,
 34424,
 12764,
 18007,
 29045,
 38531,
 28329,
 25663,
 46773,
 48722,
 26785,
 597,
 25497,
 27794,
 823,
 23282,
 26430,
 19122,
 43945,
 11780,
 16848,
 29461,
 24182,
 35482]

In [13]:
func_tokens = dataset['train']['func_code_tokens']

In [14]:
from nltk.util import bigrams

def generate_bigrams(token_list):
    return [" ".join(bigram) for bigram in list(bigrams(token_list))]

In [15]:
def inverted_index(code_tokens):
    inverted_index = {}
    
    for i in range(len(code_tokens)):
        code = code_tokens[i]
        bi = generate_bigrams(code)
        all_token = code + bi
        for token in all_token:
            if token in inverted_index.keys():
                inverted_index[token].append(i)
            else:
                inverted_index[token] = [i]
    
    return inverted_index

In [16]:
inverted_ind = inverted_index(func_tokens)

In [17]:
import math
total_documents = len(func_tokens)

def compute_tfidf(word, index):
    
    documents_with_term = len(inverted_ind[word])
    
    inner = total_documents/documents_with_term

    doc_length = len(func_tokens[index])
    
    document_frequency = inverted_ind[word].count(index)
    
    total = (document_frequency / doc_length) * math.log(inner)
    return total

In [18]:
def tfidf_search(user_input):
    
    import numpy as np
    
    track_docs = [0] * len(func_tokens)
    
    for word in (user_input.split() + generate_bigrams(user_input.split())):
        if word in inverted_ind:
            for doc_index in inverted_ind[word]:
                track_docs[doc_index] += compute_tfidf(word, doc_index)
                
    sorted_idx = list(np.argsort(track_docs)[::-1])[0:100]

    return sorted_idx

In [19]:
t_results = tfidf_search("enumerable")
t_results

[3891,
 3890,
 34424,
 42719,
 42502,
 38981,
 48790,
 16259,
 16265,
 16264,
 16263,
 16262,
 16261,
 16260,
 16257,
 16258,
 16267,
 16256,
 16255,
 16254,
 16253,
 16252,
 16251,
 16250,
 16249,
 16248,
 16266,
 16269,
 16268,
 16246,
 16288,
 16287,
 16286,
 16285,
 16284,
 16283,
 16282,
 16281,
 16280,
 16279,
 16278,
 16277,
 16276,
 16275,
 16274,
 16273,
 16272,
 16271,
 16270,
 16247,
 16245,
 16290,
 16210,
 16219,
 16218,
 16217,
 16216,
 16215,
 16214,
 16213,
 16212,
 16211,
 16209,
 16221,
 16208,
 16207,
 16206,
 16205,
 16204,
 16203,
 16202,
 16201,
 16200,
 16220,
 16222,
 16244,
 16234,
 16243,
 16242,
 16241,
 16240,
 16239,
 16238,
 16237,
 16236,
 16235,
 16233,
 16223,
 16232,
 16231,
 16230,
 16229,
 16228,
 16227,
 16226,
 16225,
 16224,
 16289,
 16292,
 16291]

In [20]:
def search_results(user_input):
    
    s = semantic_search(user_input)
    t = tfidf_search(user_input)
    overlap_results = list(set(s) & set(t))
    
    top_10_docs = overlap_results
        
    while len(top_10_docs) < 10:
        for i in s:
            if i not in top_10_docs:
                top_10_docs.append(i)
    
    if len(top_10_docs) > 10:
        top_10_docs = overlap_results[:10]
        
    function_name = []
    doc_string = []
    for i in top_10_docs:
        function_name.append(dataset['train']['func_name'][i])
        doc_string.append(dataset['train']['func_documentation_string'][i])
        
    results_df = pd.DataFrame({'Document': top_10_docs, 'Function': function_name, 'Documentation': doc_string})
    
    return results_df

In [21]:
search_results("enumerable")

Unnamed: 0,Document,Function,Documentation
0,34424,Pandata.DataFormatter.custom_sort,Sorts alphabetically ignoring the initial 'The...
1,3890,Twitter.Utils.flat_pmap,Returns a new array with the concatenated resu...
2,3891,Twitter.Utils.pmap,Returns a new array with the results of runnin...
3,38981,TeradataExtractor.Query.enumerable,"returns an enumerable, each element of which i..."
4,43259,Doublylinkedlist.Doublylinkedlist.each,Método para que la lista sea enumerable
5,42390,Yargi.ElementSet.grep,See Enumerable.grep
6,42720,MMETools.Enumerable.classify,Interessant iterador que classifica un enumera...
7,30079,Wbem.WsmanClient.each_instance,Enumerate instances
8,46812,StixSchemaSpy.SimpleType.enumeration_values,Returns the list of values for this enumeration
9,5229,Magick.ImageList.reject,override Enumerable's reject
