In [1]:

!pip install --quiet llama-index  # main llamaindex library

!pip install --quiet llama-index-vector-stores-MongoDB # mongodb vector database

!pip install --quiet llama-index-llms-anthropic # anthropic LLM provider

!pip install --quiet llama-index-embeddings-openai # openai embedding provider

!pip install --quiet beautifulsoup4

!pip install --quiet pymongo pandas datasets # others


In [15]:
import os

import config  # local module that holds the API credentials

# Copy each credential from `config` into the process environment, presumably so the
# Anthropic / Hugging Face / OpenAI clients can read them from there.
for _key in ("ANTHROPIC_API_KEY", "HF_TOKEN", "OPENAI_API_KEY"):
    os.environ[_key] = getattr(config, _key)

In [19]:
import os

from llama_index.core import Settings
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.anthropic import Anthropic

# Claude 3.5 Sonnet: used by the query engine and by the agent built later.
llm = Anthropic(model="claude-3-5-sonnet-20240620")

# text-embedding-3-small truncated to 256 dimensions.
# NOTE(review): the Atlas index "vector_index" presumably has to be declared with
# 256 dimensions as well — confirm.
embed_model = OpenAIEmbedding(
    model="text-embedding-3-small",
    dimensions=256,
    embed_batch_size=10,
    # `api_key` is the constructor's credential parameter; `openai_api_key` is not
    # one of its named parameters, so passing the key under that name has no effect.
    api_key=os.environ["OPENAI_API_KEY"],
)

# Register both as the global LlamaIndex defaults.
Settings.embed_model = embed_model
Settings.llm = llm


In [6]:
import pandas as pd
from datasets import load_dataset

# How many PrimeVul training examples to pull.
MAX_ROWS = 50000

# Dataset card: https://huggingface.co/datasets/colin/PrimeVul?row=10
# Stream the train split and keep only its first MAX_ROWS examples.
dataset = load_dataset("colin/PrimeVul", split="train", streaming=True).take(MAX_ROWS)

# Materialise the streamed examples as a DataFrame (one row per example) and preview it.
dataset_df = pd.DataFrame(dataset)
dataset_df.head(5)

Unnamed: 0,idx,project,commit_id,project_url,commit_url,commit_message,target,func,func_hash,file_name,file_hash,cwe,cve,cve_desc,nvd_url
0,0,openssl,ca989269a2876bae79393bd54c3e72d49975fc75,https://github.com/openssl/openssl,https://git.openssl.org/gitweb/?p=openssl.git;...,Use version in SSL_METHOD not SSL structure.\n...,1,long ssl_get_algorithm2(SSL *s)\n {\n ...,2.550877e+38,,,[CWE-310],CVE-2013-6449,The ssl_get_algorithm2 function in ssl/s3_lib....,https://nvd.nist.gov/vuln/detail/CVE-2013-6449
1,1,savannah,190cef6eed37d0e73a73c1e205eb31d45ab60a3c,https://git.savannah.gnu.org/gitweb/?p=gnutls,https://git.savannah.gnu.org/gitweb/?p=gnutls....,,1,gnutls_session_get_data (gnutls_session_t sess...,2.660054e+38,,,[CWE-119],CVE-2011-4128,Buffer overflow in the gnutls_session_get_data...,https://nvd.nist.gov/vuln/detail/CVE-2011-4128
2,2,savannah,e82ef4545e9e98cbcb032f55d7c750b81e3a0450,https://git.savannah.gnu.org/gitweb/?p=gnutls,https://git.savannah.gnu.org/gitweb/?p=gnutls....,,1,gnutls_session_get_data (gnutls_session_t sess...,1.6261950000000001e+38,,,[CWE-119],CVE-2011-4128,Buffer overflow in the gnutls_session_get_data...,https://nvd.nist.gov/vuln/detail/CVE-2011-4128
3,3,savannah,075d7556964f5a871a73c22ac4b69f5361295099,https://git.savannah.gnu.org/gitweb/?p=gnutls,https://git.savannah.gnu.org/cgit/wget.git/com...,,1,"getftp (struct url *u, wgint passed_expected_b...",1.147531e+38,,,[CWE-200],CVE-2015-7665,Tails before 1.7 includes the wget program but...,https://nvd.nist.gov/vuln/detail/CVE-2015-7665
4,5,ghostscript,83d4dae44c71816c084a635550acc1a51529b881,http://git.ghostscript.com/?p=mupdf,http://git.ghostscript.com/?p=mupdf.git;a=comm...,,1,void fz_init_cached_color_converter(fz_context...,1.831395e+38,colorspace.c,6.862522e+36,[CWE-20],CVE-2018-1000040,"In MuPDF 1.12.0 and earlier, multiple use of u...",https://nvd.nist.gov/vuln/detail/CVE-2018-1000040


In [7]:
import json
from llama_index.core import Document
from llama_index.core.schema import MetadataMode

# Every one of these columns is JSON-encoded, so each metadata value becomes a string.
SERIALIZED_FIELDS = [
    "idx",
    "project",
    "commit_id",
    "project_url",
    "commit_url",
    "commit_message",
    "target",
    "func",
    "func_hash",
    "file_name",
    "file_hash",
    "cwe",
    "cve",
    "cve_desc",
    "nvd_url",
]

# Metadata hidden from both the LLM and the embedding model. What remains visible is
# commit_message, target, func, func_hash, cve and cve_desc.
EXCLUDED_METADATA_KEYS = [
    "idx", "project", "commit_id", "project_url", "commit_url",
    "file_name", "file_hash", "cwe", "nvd_url",
]

# Documents whose metadata string is at least this many characters are dropped.
# Presumably this keeps metadata below the TokenTextSplitter chunk_size (10000) used
# when splitting — confirm.
MAX_METADATA_CHARS = 10000

documents_list = json.loads(dataset_df.to_json(orient="records"))

llama_documents = []
excludedDocs = 0  # counts documents dropped by the size filter across the whole run
maxSize = 0       # longest metadata string seen, including dropped documents

for document in documents_list:
    for field in SERIALIZED_FIELDS:
        document[field] = json.dumps(document[field])

    # NOTE(review): commit_message is both the document text and a non-excluded
    # metadata key, so it appears twice in what the LLM and embedder see.
    llama_document = Document(
        text=document["commit_message"],
        metadata=document,
        # Separate copies so the two exclusion lists never alias each other.
        excluded_llm_metadata_keys=list(EXCLUDED_METADATA_KEYS),
        excluded_embed_metadata_keys=list(EXCLUDED_METADATA_KEYS),
        metadata_template="{key}=>{value}",
        text_template="Metadata: {metadata_str}\n-----\nContent: {content}",
    )

    metadata_size = len(llama_document.get_metadata_str())
    maxSize = max(maxSize, metadata_size)
    if metadata_size < MAX_METADATA_CHARS:
        llama_documents.append(llama_document)
    else:
        excludedDocs += 1

print(
    f"Kept {len(llama_documents)} documents, excluded {excludedDocs} "
    f"(longest metadata string: {maxSize} chars)"
)

# Observing input examples
if llama_documents:
    print("\nThe LLM sees this: \n", llama_documents[0].get_content(metadata_mode=MetadataMode.LLM))
    print("\nThe Embedding model sees this: \n", llama_documents[0].get_content(metadata_mode=MetadataMode.EMBED))
else:
    print("No documents passed the metadata size filter.")



The LLM sees this: 
 Metadata: commit_message=>"Use version in SSL_METHOD not SSL structure.\n\nWhen deciding whether to use TLS 1.2 PRF and record hash algorithms\nuse the version number in the corresponding SSL_METHOD structure\ninstead of the SSL structure. The SSL structure version is sometimes\ninaccurate. Note: OpenSSL 1.0.2 and later effectively do this already.\n(CVE-2013-6449)"
target=>1
func=>" long ssl_get_algorithm2(SSL *s)\n        {\n        long alg2 = s->s3->tmp.new_cipher->algorithm2;\n       if (TLS1_get_version(s) >= TLS1_2_VERSION &&\n            alg2 == (SSL_HANDSHAKE_MAC_DEFAULT|TLS1_PRF))\n                return SSL_HANDSHAKE_MAC_SHA256 | TLS1_PRF_SHA256;\n        return alg2;\n\t}\n"
func_hash=>2.550877477e+38
cve=>"CVE-2013-6449"
cve_desc=>"The ssl_get_algorithm2 function in ssl/s3_lib.c in OpenSSL before 1.0.2 obtains a certain version number from an incorrect data structure, which allows remote attackers to cause a denial of service (daemon crash) via crafte

In [8]:
from llama_index.core.node_parser import TokenTextSplitter
from llama_index.core.schema import MetadataMode
from tqdm.auto import tqdm

# Split documents into token chunks of up to 10000 tokens with 200 tokens of overlap.
base_splitter = TokenTextSplitter(chunk_size=10000, chunk_overlap=200)
nodes = base_splitter.get_nodes_from_documents(llama_documents)

# Embed in batches of the model's configured embed_batch_size: one API request per
# batch instead of one per node. Embeddings are assigned batch by batch, so nodes
# embedded before any failure keep their vectors.
batch_size = max(1, embed_model.embed_batch_size)

with tqdm(total=len(nodes), desc="Embedding Progress", unit="node") as pbar:
    for start in range(0, len(nodes), batch_size):
        batch = nodes[start:start + batch_size]
        texts = [node.get_content(metadata_mode=MetadataMode.EMBED) for node in batch]
        embeddings = embed_model.get_text_embedding_batch(texts)
        for node, node_embedding in zip(batch, embeddings):
            node.embedding = node_embedding
        pbar.update(len(batch))

print("Embedding process completed!")


Embedding Progress: 100%|██████████| 49091/49091 [3:42:54<00:00,  3.67node/s]   

Embedding process completed!





In [16]:
import pymongo

os.environ["MONGO_URI"] = config.MONGO_URI


def get_mongo_client(mongo_uri):
    """Create a MongoClient for ``mongo_uri`` and verify it with a ``ping``.

    Args:
        mongo_uri: MongoDB connection string.

    Returns:
        The connected ``pymongo.MongoClient`` when the ping reports ``ok == 1``;
        otherwise ``None``. In that case the client is closed first.

    Raises:
        Any exception raised by the ping itself (e.g. the server is unreachable).
        The client is closed before the exception propagates.
    """
    client = pymongo.MongoClient(mongo_uri, appname="devrel.showcase.python")

    try:
        ping_result = client.admin.command("ping")
    except Exception:
        client.close()  # release the client's connections before propagating
        raise

    if ping_result.get("ok") == 1.0:
        print("Connection to MongoDB successful")
        return client

    print("Connection to MongoDB failed")
    client.close()
    return None


mongo_client = get_mongo_client(os.environ["MONGO_URI"])
# Fail here with a clear message instead of an AttributeError on None below.
if mongo_client is None:
    raise RuntimeError("Could not connect to MongoDB; check MONGO_URI in config.")

DB_NAME = "Claude"
COLLECTION_NAME = "PrimeVulData"

db = mongo_client.get_database(DB_NAME)
collection = db.get_collection(COLLECTION_NAME)


Connection to MongoDB successful


In [17]:
from llama_index.vector_stores.mongodb import MongoDBAtlasVectorSearch

# Atlas vector store over DB_NAME.COLLECTION_NAME, queried via the Atlas search index
# named "vector_index" (that index must already exist on the collection).
vector_store = MongoDBAtlasVectorSearch(
    mongo_client,
    db_name=DB_NAME,
    collection_name=COLLECTION_NAME,
    vector_index_name="vector_index",
)

# Ingest the embedded nodes only when the collection is empty. A fresh run then stores
# what it just embedded, while re-runs against an already-populated collection
# presumably do not insert duplicates. `nodes` is only needed on the ingest path.
if collection.count_documents({}, limit=1) == 0:
    vector_store.add(nodes)

In [20]:
from llama_index.core import VectorStoreIndex
from llama_index.core.tools import QueryEngineTool, ToolMetadata

# Index over the existing vector store; queries retrieve the 5 most similar nodes and
# answer with Claude.
index = VectorStoreIndex.from_vector_store(vector_store)
query_engine = index.as_query_engine(similarity_top_k=5, llm=llm)

# The tool name/description are what the agent's LLM reads when deciding to call it.
# The two literals are concatenated, so each ends with ". " or "." to avoid them
# running together into one garbled word.
query_engine_tool = QueryEngineTool(
    query_engine=query_engine,
    metadata=ToolMetadata(
        name="knowledge_base",
        description=(
            "Detects code vulnerabilities and details specific vulnerabilities. "
            "Use text input."
        ),
    ),
)


In [21]:
from llama_index.core.agent import FunctionCallingAgentWorker

# A function-calling worker that Claude drives, with the knowledge-base tool as its
# only tool. verbose=True logs each LLM step and tool call.
agent_tools = [query_engine_tool]
agent_worker = FunctionCallingAgentWorker.from_tools(agent_tools, llm=llm, verbose=True)

# Wrap the worker in a chat-capable agent.
agent = agent_worker.as_agent()


In [51]:
# Smoke-test the agent: ask it to classify a short session-cookie snippet.
question = """
Can you determine the CWE for this:
String sessionID = generateSessionId();Cookie c = new Cookie("session_id", sessionID);response.addCookie(c);
"""

response = agent.chat(question)
print(response)  # print() applies str() to the agent response


Added user message to memory: 
Can you determine the CWE for this:
String sessionID = generateSessionId();Cookie c = new Cookie("session_id", sessionID);response.addCookie(c);

=== LLM Response ===
Certainly! I'll analyze this code snippet and use the knowledge base to determine the most likely CWE (Common Weakness Enumeration) associated with it. Let me query the database for this information.
=== Calling Function ===
Calling function: knowledge_base with args: {"input": "Determine the CWE for this code:\nString sessionID = generateSessionId();\nCookie c = new Cookie(\"session_id\", sessionID);\nresponse.addCookie(c);"}
=== Function Output ===
Based on the code snippet provided, the potential Common Weakness Enumeration (CWE) that could apply is CWE-614: Sensitive Cookie in HTTPS Session Without 'Secure' Flag.

This code creates a session ID cookie and adds it to the response, but it doesn't set any security attributes on the cookie. Specifically, it's missing the 'Secure' flag, which

In [23]:
import os

# Root of the evaluation set: presumably one sub-folder per CWE, each containing
# code-snippet files — confirm.
folder_path = "./data_samples/use/cwe_samples"

# Maximum number of snippets sent to the agent (each one is a full agent run).
MAX_FILES = 10


def build_review_prompt(code_snippet):
    """Return the vulnerability-review prompt with ``code_snippet`` appended.

    The prompt asks the model to answer strictly in the VULNERABLE /
    VULNERABLE_LINES / VULNERABILITY_DESCRIPTION / EXPLANATION / RECOMMENDATIONS
    template.
    """
    return f"""Given a segment of code, determine if a vulnerability is present.
If you find that a vulnerability is present, then describe which lines contain the vulnerability.
If you find that a vulnerability is present, then describe what the vulnerability is.
If you find that a vulnerability is present, then describe recommendations for fixing the vulnerability.
Think through your response to ensure that it is as accurate as possible.

Additional context:
You will be given code snippets from hunks of github commits.
This means you may be missing variable declarations or definitions, import statements, or other such items.
Ignore warnings or issues that stem from you only evaluating a code snippet.
Do not concern yourself with how the code sample fits in to a larger file or project.
Focus on identifying vulnerabilities within the code snippet.
There are immense stakes at risk, and you cannot afford to be incorrect.

The format of your response must exactly follow this template. Do not include ANY response other than this template.-
VULNERABLE: YES/NO
VULNERABLE_LINES: LineNumbers/None
VULNERABILITY_DESCRIPTION:
Description of the vulnerability
EXPLANATION:
Provide a more detailed explanation of your analysis here.
RECOMMENDATIONS:
Include recommended fixes for this code.


Here is the segment you will be evaluating:\n{code_snippet}"""


file_count = 0
results = {}  # agent response text keyed by snippet file path

# sorted() makes "the first MAX_FILES snippets" deterministic; os.listdir order is arbitrary.
for file_name in sorted(os.listdir(folder_path)):
    if file_count >= MAX_FILES:
        break
    file_path = os.path.join(folder_path, file_name)
    # Only sub-folders hold samples; listing a stray file would raise NotADirectoryError.
    if not os.path.isdir(file_path):
        continue
    for test_file in sorted(os.listdir(file_path)):
        if file_count >= MAX_FILES:
            break
        test_file_path = os.path.join(file_path, test_file)
        if not os.path.isfile(test_file_path):
            continue
        with open(test_file_path, "r", encoding="utf-8", errors="replace") as snippet_file:
            code_snippet = snippet_file.read()
        # Clear the agent's chat memory so each snippet is judged on its own, without
        # earlier snippets and answers (or the smoke-test question) in the context.
        agent.reset()
        response = agent.chat(build_review_prompt(code_snippet))
        print(response)
        results[test_file_path] = str(response)
        file_count += 1


Added user message to memory: Given a segment of code, determine if a vulnerability is present.
If you find that a vulnerability is present, then describe which lines contain the vulnerability.
If you find that a vulnerability is present, then describe what the vulnerability is.
If you find that a vulnerability is present, then describe recommendations for fixing the vulnerability.
Think through your response to ensure that it is as accurate as possible.

Additional context:
You will be given code snippets from hunks of github commits.
This means you may be missing variable declarations or definitions, import statements, or other such items.
Do not concern yourself with how the code sample fits in to a larger file or project.
Focus on identifying vulnerabilities within the code snippet.
There are immense stakes at risk, and you cannot afford to be incorrect.

The format of your response must exactly follow this template. Do not include ANY response other than this template.-
VULNERABLE