In [3]:

!pip install --quiet llama-index  # main llamaindex library

!pip install --quiet llama-index-vector-stores-MongoDB # mongodb vector database

!pip install --quiet llama-index-llms-anthropic # anthropic LLM provider

!pip install --quiet llama-index-embeddings-openai # openai embedding provider

!pip install --quiet beautifulsoup4

!pip install --quiet pymongo pandas datasets # others


In [4]:
import os
import config

os.environ["ANTHROPIC_API_KEY"] = config.ANTHROPIC_API_KEY
os.environ["HF_TOKEN"] = config.HF_TOKEN
os.environ["OPENAI_API_KEY"] = config.OPENAI_API_KEY

In [5]:
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.anthropic import Anthropic
from llama_index.core import Settings
llm = Anthropic(model="claude-3-5-sonnet-20240620")

embed_model = OpenAIEmbedding(
    model="text-embedding-3-small",
    dimensions=256,
    embed_batch_size=10,
    openai_api_key=os.environ["OPENAI_API_KEY"]
)

Settings.embed_model = embed_model
Settings.llm = llm


In [6]:
from datasets import load_dataset
import pandas as pd

# https://huggingface.co/datasets/MongoDB/airbnb_embeddings
dataset = load_dataset("colin/PrimeVul", split="train", streaming=True)
dataset = dataset.take(10000)

# Convert the dataset to a pandas dataframe
dataset_df = pd.DataFrame(dataset)
dataset_df.head(5)

Unnamed: 0,idx,project,commit_id,project_url,commit_url,commit_message,target,func,func_hash,file_name,file_hash,cwe,cve,cve_desc,nvd_url
0,0,openssl,ca989269a2876bae79393bd54c3e72d49975fc75,https://github.com/openssl/openssl,https://git.openssl.org/gitweb/?p=openssl.git;...,Use version in SSL_METHOD not SSL structure.\n...,1,long ssl_get_algorithm2(SSL *s)\n {\n ...,2.550877e+38,,,[CWE-310],CVE-2013-6449,The ssl_get_algorithm2 function in ssl/s3_lib....,https://nvd.nist.gov/vuln/detail/CVE-2013-6449
1,1,savannah,190cef6eed37d0e73a73c1e205eb31d45ab60a3c,https://git.savannah.gnu.org/gitweb/?p=gnutls,https://git.savannah.gnu.org/gitweb/?p=gnutls....,,1,gnutls_session_get_data (gnutls_session_t sess...,2.660054e+38,,,[CWE-119],CVE-2011-4128,Buffer overflow in the gnutls_session_get_data...,https://nvd.nist.gov/vuln/detail/CVE-2011-4128
2,2,savannah,e82ef4545e9e98cbcb032f55d7c750b81e3a0450,https://git.savannah.gnu.org/gitweb/?p=gnutls,https://git.savannah.gnu.org/gitweb/?p=gnutls....,,1,gnutls_session_get_data (gnutls_session_t sess...,1.6261950000000001e+38,,,[CWE-119],CVE-2011-4128,Buffer overflow in the gnutls_session_get_data...,https://nvd.nist.gov/vuln/detail/CVE-2011-4128
3,3,savannah,075d7556964f5a871a73c22ac4b69f5361295099,https://git.savannah.gnu.org/gitweb/?p=gnutls,https://git.savannah.gnu.org/cgit/wget.git/com...,,1,"getftp (struct url *u, wgint passed_expected_b...",1.147531e+38,,,[CWE-200],CVE-2015-7665,Tails before 1.7 includes the wget program but...,https://nvd.nist.gov/vuln/detail/CVE-2015-7665
4,5,ghostscript,83d4dae44c71816c084a635550acc1a51529b881,http://git.ghostscript.com/?p=mupdf,http://git.ghostscript.com/?p=mupdf.git;a=comm...,,1,void fz_init_cached_color_converter(fz_context...,1.831395e+38,colorspace.c,6.862522e+36,[CWE-20],CVE-2018-1000040,"In MuPDF 1.12.0 and earlier, multiple use of u...",https://nvd.nist.gov/vuln/detail/CVE-2018-1000040


In [7]:
import json
from llama_index.core import Document
from llama_index.core.schema import MetadataMode

documents_json = dataset_df.to_json(orient='records')
documents_list = json.loads(documents_json)

llama_documents = []

maxSize = 0

for document in documents_list:
    # Convert complex objects to JSON strings
    for field in [
        "idx",
        "project",
        "commit_id",
        "project_url",
        "commit_url",
        "commit_message",
        "target",
        "func",
        "func_hash",
        "file_name",
        "file_hash",
        "cwe",
        "cve",
        "cve_desc",
        "nvd_url"
    ]:
        document[field] = json.dumps(document[field])

    # ["idx", "project", "commit_id", "project_url", "commit_url", "commit_message", "target", "func", "func_hash", "file_name", "file_hash", "cwe", "cve", "cve_desc", "nvd_url"]

    excludedDocs = 0

    # Keep commit message, target, func, func hash, cve, cve desc
    # Create a Document object
    llama_document = Document(
        text=document["commit_message"],
        metadata=document,
        excluded_llm_metadata_keys=["idx", "project", "commit_id", "project_url", "commit_url", "file_name", "file_hash", "cwe", "nvd_url"],
        excluded_embed_metadata_keys=["idx", "project", "commit_id", "project_url", "commit_url", "file_name", "file_hash", "cwe", "nvd_url"],
        metadata_template="{key}=>{value}",
        text_template="Metadata: {metadata_str}\n-----\nContent: {content}",
    )

    if len(llama_document.get_metadata_str()) < 10000:
        llama_documents.append(llama_document)
    else:
        excludedDocs += 1

    maxSize = max(maxSize, len(llama_document.get_metadata_str()))

# Observing input examples
print("\nThe LLM sees this: \n", llama_documents[0].get_content(metadata_mode=MetadataMode.LLM))
print("\nThe Embedding model sees this: \n", llama_documents[0].get_content(metadata_mode=MetadataMode.EMBED))

print("\nGreatest Metadata Size:", maxSize)
print("\nExlucded",excludedDocs)



The LLM sees this: 
 Metadata: commit_message=>"Use version in SSL_METHOD not SSL structure.\n\nWhen deciding whether to use TLS 1.2 PRF and record hash algorithms\nuse the version number in the corresponding SSL_METHOD structure\ninstead of the SSL structure. The SSL structure version is sometimes\ninaccurate. Note: OpenSSL 1.0.2 and later effectively do this already.\n(CVE-2013-6449)"
target=>1
func=>" long ssl_get_algorithm2(SSL *s)\n        {\n        long alg2 = s->s3->tmp.new_cipher->algorithm2;\n       if (TLS1_get_version(s) >= TLS1_2_VERSION &&\n            alg2 == (SSL_HANDSHAKE_MAC_DEFAULT|TLS1_PRF))\n                return SSL_HANDSHAKE_MAC_SHA256 | TLS1_PRF_SHA256;\n        return alg2;\n\t}\n"
func_hash=>2.550877477e+38
cve=>"CVE-2013-6449"
cve_desc=>"The ssl_get_algorithm2 function in ssl/s3_lib.c in OpenSSL before 1.0.2 obtains a certain version number from an incorrect data structure, which allows remote attackers to cause a denial of service (daemon crash) via crafte

In [8]:
from llama_index.core.node_parser import TokenTextSplitter
from llama_index.core.schema import MetadataMode
from tqdm import tqdm

base_splitter = TokenTextSplitter(chunk_size=10000, chunk_overlap=200)

nodes = base_splitter.get_nodes_from_documents(llama_documents)

# Progress bar
pbar = tqdm(total=len(nodes), desc="Embedding Progress", unit="node")

for node in nodes:
    node_embedding = embed_model.get_text_embedding(
        node.get_content(metadata_mode=MetadataMode.EMBED)
    )
    node.embedding = node_embedding
    
    # Update the progress bar
    pbar.update(1)

# Close the progress bar
pbar.close()

print("Embedding process completed!")


Embedding Progress: 100%|██████████| 9572/9572 [1:01:07<00:00,  2.61node/s]   

Embedding process completed!





In [20]:
import pymongo

os.environ["MONGO_URI"] = config.MONGO_URI

def get_mongo_client(mongo_uri):
    """Establish and validate connection to the MongoDB."""
    
    client = pymongo.MongoClient(mongo_uri, appname="devrel.showcase.python")

    # Validate the connection
    ping_result = client.admin.command('ping')
    if ping_result.get('ok') == 1.0:
        # Connection successful
        print("Connection to MongoDB successful")
        return client
    else:
        print("Connection to MongoDB failed")
    return None


mongo_client = get_mongo_client(config.MONGO_URI)

DB_NAME = "airbnb"
COLLECTION_NAME = "listings_reviews"

db = mongo_client.get_database(DB_NAME)
collection = db.get_collection(COLLECTION_NAME)


Connection to MongoDB successful


In [21]:
from llama_index.vector_stores.mongodb import MongoDBAtlasVectorSearch

vector_store = MongoDBAtlasVectorSearch(
    mongo_client, 
    db_name=DB_NAME, 
    collection_name=COLLECTION_NAME, 
    index_name="vector_index"
)

vector_store.add(nodes)

index_name is deprecated. Please use vector_index_name
vector_index_name and index_name both specified. Will use vector_index_name


['7a1a2f46-01b5-45bb-9b06-08757aa5f85e',
 '0ce725c7-7895-4df9-81de-7ea67ee5a03b',
 'c0d2b4b0-5c5c-48e7-b84f-83ffcff31b83',
 '2d825c01-ef3e-4c7a-91c1-93fc10769e17',
 '33c09379-014b-42f3-a65a-c78c443f653f',
 'bcc5ecaf-c8e7-49a8-bb2e-f8f38fffba76',
 'df270b54-dca7-4b85-a6e3-59155db71a52',
 'e7e6479a-1382-495d-ae75-f0b08dbbfa5d',
 '0aca5933-93e2-461d-8f8b-4ce47fe480ed',
 'b0b337ac-10ba-4552-910c-dcb88f2a7f5e',
 '2b99869b-bcbe-4e00-bc72-1c60550240fc',
 '13e124da-1fe7-4b4a-89c0-355db7ef2500',
 '738f9df4-4977-4b09-a531-2da8a22c6d98',
 'ceca1c9b-4375-43d8-8218-e6531aec3e3d',
 '97a59aaf-837d-4125-9d06-915c4e127d8d',
 '639be01e-8b96-4ef8-8679-89aadc34281a',
 '043c58ba-358f-4152-9a60-729883e3aa6e',
 'fae0495d-782c-4a12-967a-e3c99d8c6e0e',
 '137ca518-8699-4de5-83f5-77d5a95f3cbc',
 'c1ad08b5-a7a9-4201-81f4-7a08f42694e8',
 'e2705e29-135c-4be8-8edd-4ac2e430cae7',
 '1acff592-2e8f-4b67-bf77-37dc66005124',
 '3251b552-b927-4a3e-b660-b7397d4c918b',
 'a211f66b-7074-420f-8276-56c68620a815',
 'b55e3e5f-7710-

In [22]:
from llama_index.core import VectorStoreIndex
from llama_index.core.tools import QueryEngineTool, ToolMetadata

index = VectorStoreIndex.from_vector_store(vector_store)
query_engine = index.as_query_engine(similarity_top_k=5, llm=llm)

query_engine_tool = QueryEngineTool(
    query_engine=query_engine,
    metadata=ToolMetadata(
        name="knowledge_base",
        description=(
            "Provides information about Airbnb listings and reviews."
            "Use a detailed plain text question as input to the tool."
        ),
    ),
)


In [23]:
from llama_index.core.agent import FunctionCallingAgentWorker

agent_worker = FunctionCallingAgentWorker.from_tools(
    [query_engine_tool], llm=llm, verbose=True
)
agent = agent_worker.as_agent()


In [24]:
response = agent.chat("""
#include <unistd.h>
#include <stdlib.h>
#include <stdlib.h>
void *mymalloc(unsigned int size) { return malloc(size); }

int main()
{
    char *buf;
    size_t len;
    read(0, &len, sizeof(len));
    /* we forgot to check the maximum length */
    /* 64-bit size_t gets truncated to 32-bit unsigned int */
    buf = mymalloc(len);
    read(0, buf, len);
    return 0;
}
                      
                      There is a vulnerability in this code, what is it and specify the CVE ID
                      """)
print(str(response))


Added user message to memory: 
#include <unistd.h>
#include <stdlib.h>
#include <stdlib.h>
void *mymalloc(unsigned int size) { return malloc(size); }

int main()
{
    char *buf;
    size_t len;
    read(0, &len, sizeof(len));
    /* we forgot to check the maximum length */
    /* 64-bit size_t gets truncated to 32-bit unsigned int */
    buf = mymalloc(len);
    read(0, buf, len);
    return 0;
}
                      
                      There is a vulnerability in this code, what is it and specify the CVE ID
                      
=== LLM Response ===
To answer your question about the vulnerability in this code and provide a specific CVE ID, I'll need to consult the knowledge base. Let me do that for you.
=== Calling Function ===
Calling function: knowledge_base with args: {"input": "What is the vulnerability in the given C code that involves truncation of a 64-bit size_t to a 32-bit unsigned int in a memory allocation function? Please provide the specific CVE ID if possible."}
==