In [19]:
import chromadb
from chromadb.utils import embedding_functions
import pandas as pd

In [20]:
%load_ext dotenv
%dotenv

The dotenv extension is already loaded. To reload it, use:
  %reload_ext dotenv


In [21]:
cwe_df = pd.read_csv('data/all.csv')
print(cwe_df.head())

   CWE-ID                                               Name  \
0       5  J2EE Misconfiguration: Data Transmission Witho...   
1       6  J2EE Misconfiguration: Insufficient Session-ID...   
2       7   J2EE Misconfiguration: Missing Custom Error Page   
3       8  J2EE Misconfiguration: Entity Bean Declared Re...   
4       9  J2EE Misconfiguration: Weak Access Permissions...   

  Weakness Abstraction      Status  \
0              Variant       Draft   
1              Variant  Incomplete   
2              Variant  Incomplete   
3              Variant  Incomplete   
4              Variant       Draft   

                                         Description  \
0  Information sent over a network can be comprom...   
1  The J2EE application is configured to use an i...   
2  The default error page of a web application sh...   
3  When an application exposes a remote interface...   
4  If elevated access rights are assigned to EJB ...   

                                Extended Descript

In [24]:
import json
import glob
import random

# Now, let's select 1000 random GHSA-*.json files as a sample test set from data/advisories/github-reviewed
files = glob.glob('data/advisories/github-reviewed/*/*/*/*.json', recursive=True)

sample_files = random.sample(files, 1000)

# Some of these files contain a "Rejected reason" in their details, which make them unsuitable for our analysis.
# We will filter out those files.
sample_files = [file for file in sample_files if 'Rejected reason' not in open(file, 'r').read()]

print(f"Selected {len(sample_files)} files for analysis.")

# Lets generate a DataFrame by reading each file, extracting the "summary", "details", and "database_specific.cwe_ids" fields, and normalizing it
data = []
for file in sample_files:
    with open(file, 'r') as f:
        content = json.load(f)
        id = content.get('id', '')
        summary = content.get('summary', '')
        details = content.get('details', '')
        cwe_ids = content.get('database_specific', {}).get('cwe_ids', [])
        # Normalize cwe_ids to a comma-separated string, and strip out the prefix "CWE-"
        cwe_ids = ','.join([cwe_id.replace('CWE-', '') for cwe_id in cwe_ids])
        if len(cwe_ids) == 0:
            continue
        data.append({
            'id': id,
            'summary': summary,
            'details': details,
            'cwe_ids': cwe_ids
        })

print(f"Selected {len(data)} GHSAs for analysis")
sample_df = pd.DataFrame(data)
print(sample_df.head())

Selected 1000 files for analysis.
Selected 927 GHSAs for analysis
                    id                                            summary  \
0  GHSA-4jrx-5w4h-3gpm        Navidrome Parameter Tampering vulnerability   
1  GHSA-4363-x42f-xph6              Malicious Package in hw-trnasport-u2f   
2  GHSA-cxmj-qjv6-vx9p         mcstatic directory traversal vulnerability   
3  GHSA-xp26-p53h-6h2p  Improper Neutralization of Input During Web Pa...   
4  GHSA-8vj9-5v5q-fhch          Bonita cross-site scripting vulnerability   

                                             details cwe_ids  
0  ### Summary\nParameter tampering is a vulnerab...     200  
1  All versions of this package contained malware...     506  
2  A server directory traversal vulnerability was...      22  
3  An issue was discovered in lxml before 4.2.5. ...      79  
4  Bonita before 10.1.0.W11 allows stored XSS via...      79  


In [11]:
# Let's see how many tokens we have in all of the text
# If its small enough, we can just add it to the context as part of the prompt
# First, lets turn all of IDs, Descriptions, and Extended Descriptions into one text blob
cwe_text_blob = ""
for i, row in cwe_df.iterrows():
    cwe_text_blob += str(row['CWE-ID']) + "," + str(row['Description']) + "," + str(row['Extended Description']) + "\n"

import tiktoken
encoding = tiktoken.get_encoding("cl100k_base")
tokens = encoding.encode(cwe_text_blob)
print(f"Total number of tokens in text blob: {len(tokens)}")

Total number of tokens in text blob: 100979


That's quite a lot of tokens! So we can't just yeet everything into the prompt. We should do some preprocessing before we use it in the prompt.

# Helper Methods

In [25]:
# Some helper functions for everything
def generate_input_blob(row):
    return row['summary'] + "\n" + row['details']

In [56]:
from openai import OpenAI
import os

token = os.environ["GITHUB_TOKEN"]
endpoint = "https://models.github.ai/inference"
model = "openai/gpt-4.1"
client = OpenAI(
    base_url=endpoint,
    api_key=token,
)

def query_model(prompt, input_text, model=model, temperature=0.7):
    response = client.chat.completions.create(
        messages=[
            {
                "role": "system",
                "content": prompt,
            },
            {
                "role": "user",
                "content": input_text,
            }
        ],
        temperature=temperature,
        top_p=1.0,
        model=model
    )
    return response.choices[0].message.content.strip()

In [52]:
def test_accuracy(predictor, test_df, samples):
    sample_df = test_df.sample(n=samples)
    accurate = 0
    for idx, row in sample_df.iterrows():
        input_blob = generate_input_blob(row)
        expected_cwe_ids = row['cwe_ids'].split(',') if row['cwe_ids'] else []
        actual_cwe_id = predictor(input_blob)
        if actual_cwe_id in expected_cwe_ids:
            accurate += 1
    return accurate

## Embeddings Search Approach

In [31]:
client = chromadb.Client()
# Use OpenAI embedding function from chromadb
openai_ef = embedding_functions.DefaultEmbeddingFunction()
collection = client.get_or_create_collection(name="cwe_index")

# Prepare the text data: concatenate Description and Extended Description
texts = (
    cwe_df['Description'].fillna('') + " " + cwe_df['Extended Description'].fillna('')
).tolist()
metadatas = (
    cwe_df[['CWE-ID', 'Name']].fillna('').to_dict(orient='records')
)

# Generate embeddings
embeddings = openai_ef(texts)
collection.add(
    documents=texts,
    embeddings=embeddings,
    metadatas=metadatas,
    ids=cwe_df['CWE-ID'].astype(str).tolist()  # Use CWE-ID as the unique identifier
)

In [55]:
# Now, for each sample in the sample_df, we will find the most similar CWE entry.
# We will then compare the cwe_ids from the sample_df with the cwe_ids from the most similar CWE entry.
def embedding_search_predictor(input_text):
    query_result = collection.query(
        query_texts=[input_text],
        n_results=1,
        include=["documents", "metadatas"]
    )
    return query_result['ids'][0][0] if query_result['ids'] else None

total_samples = len(sample_df)
accurate = test_accuracy(embedding_search_predictor, sample_df, total_samples)
print(f"Accuracy: {accurate}/{total_samples} = {accurate / total_samples:.2%}")

Accuracy: 185/927 = 19.96%


## Zero-shot Prompting

In [78]:
def zero_shot_classification(input_text):
    prompt = (
        "You are an expert in CWE classification. "
        "Given a summary and details of a vulnerability, classify it into one or more CWE categories. "
        "Return only the numerical ID of the CWE."
        "If you cannot classify it, return 'None'. "
        "The input will be in the format: 'Summary: <summary> Details: <details>'."
    )
    return query_model(prompt, input_text)

total_samples = 20
accurate = test_accuracy(zero_shot_classification, sample_df, total_samples)
print(f"Accuracy: {accurate}/{total_samples} = {accurate / total_samples:.2%}")

Accuracy: 13/20 = 65.00%


## Prompting then Cosine Similarity

In [65]:
# Let's try cleaning up the text blobs with an LLM first, and then feeding it into the embedding function. Maybe this way we can reduce the number of tokens used.

def zero_shot_classification_with_cosine_similarity(input_text):
    """Use the AI model to predict the classification of the given text."""
    system_prompt = """You are a vulnerability analyst who works to classify vulnerabilities. Given the following vulnerability description, output the name of the most likely classification for the vulnerability, in a format similar to the CWE classifications. For example:

Permissive Regular Expression
Uncontrolled Resource Consumption
Server-Side Request Forgery (SSRF)
Reusing a Nonce, Key Pair in Encryption
Return only the title of the classification."""
    predicted_classification = query_model(system_prompt, input_text)

    query_result = collection.query(
        query_texts=[predicted_classification],
        n_results=1,
        include=["documents", "metadatas"]
    )
    return query_result['ids'][0][0]

total_samples = 10
accurate = test_accuracy(zero_shot_classification_with_cosine_similarity, sample_df, total_samples)
print(f"Accuracy: {accurate}/{total_samples} = {accurate / total_samples:.2%}")

Accuracy: 5/10 = 50.00%


## Hierarchical Multi Step Prompting

CWEs are structured in a hierarchy, so we can use that to our advantage. We can first prompt for the CWE category, then the specific CWE within that category, and so on until we reach the specific CWE. This way, we can reduce the number of tokens in each prompt, use a smaller model, and still get good results.

In [72]:
class CWENode:
    def __init__(self, cwe_id, name):
        self.cwe_id = cwe_id
        self.name = name
        self.description = ""
        self.extended_description = ""
        self.children = []

def load_cwe_tree(file_path):
    with open(file_path, 'r') as file:
        cwe_tree_json = json.load(file)
    
    # Now we need to turn this adjacency list into a tree structure
    # The structure of this tree is a dictionary where each key is a CWE ID and the value is a list of child CWE IDs
    # cwe_tree_json is expected to be a dict: {<cwe_id>: [<children_cwe_ids, ...], ...}
    # The name, description, and extended description are not included in the JSON, so we will need to fetch them from the cwe_df DataFrame.
    # Create all nodes first
    nodes = {}
    for cwe_id in cwe_tree_json.keys():
        # Fetch name, description, and extended description from cwe_df
        cwe_row = cwe_df[cwe_df['CWE-ID'].astype(str) == str(cwe_id)]
        name = cwe_row['Name'].values[0] if not cwe_row.empty else ""
        node = CWENode(cwe_id, name)
        node.description = cwe_row['Description'].values[0] if not cwe_row.empty else ""
        node.extended_description = cwe_row['Extended Description'].values[0] if not cwe_row.empty else ""
        nodes[cwe_id] = node

    # Link children
    for parent_id, children_ids in cwe_tree_json.items():
        parent_node = nodes[parent_id]
        for child_id in children_ids:
            if child_id in nodes:
                parent_node.children.append(nodes[child_id])

    # Find root nodes (those not referenced as children)
    all_children = {child_id for children in cwe_tree_json.values() for child_id in children}
    root_ids = [cwe_id for cwe_id in cwe_tree_json.keys() if cwe_id not in all_children]
    roots = [nodes[cwe_id] for cwe_id in root_ids]
    return roots
    
cwe_tree = load_cwe_tree('data/cwe_tree.json')
print(len(cwe_tree), "root nodes found in the CWE tree.")

10 root nodes found in the CWE tree.
['707', '664', '682', '693', '697', '435', '703', '284', '691', '710']


In [73]:
# The system prompt for the AI model to classify vulnerabilities based on CWE nodes
def generate_hierarchical_system_prompt(cwe_node):
    hierarchical_system_prompt = """You are a vulnerability analyst who works to classify vulnerabilities. Given the following vulnerability descriptions, output the name of the most likely classification for the vulnerability, out of the following categories:
"""
    hierarchical_system_prompt += f"\n{cwe_node.cwe_id} - {cwe_node.name}\n {cwe_node.description}\n---\n"
    for child in cwe_node.children:
        hierarchical_system_prompt += f"{child.cwe_id} - {child.name}\n {child.description}\n---\n"

    hierarchical_system_prompt += """Return only the numerical ID of the classification, do not return any other text."""
    return hierarchical_system_prompt

In [77]:
def hierarchical_system_prompting(input_text):
    # Let's create a new root node for the CWE tree, so we can run a for loop over the root nodes
    # We only care about software development ones, so let's start from 1000
    cwe_root_node = CWENode("CWE-1000", "Root CWE Node, DO NOT USE for classification")
    cwe_root_node.children = cwe_tree

    curr_node = cwe_root_node
    while True:
        if len(curr_node.children) == 0:
            return curr_node.cwe_id

        hierarchical_prompt = generate_hierarchical_system_prompt(curr_node)
        predicted_classification = query_model(hierarchical_prompt, input_text, model="openai/gpt-4.1-mini", temperature=0.95)

        # Check that the predicted classification is a valid integer
        try:
            int(predicted_classification)
        except ValueError:
            return curr_node.cwe_id

        # Check if the predicted classification is the same as the current node's CWE ID
        if predicted_classification == curr_node.cwe_id:
            return curr_node.cwe_id

        # If not, check if the predicted classification is one of the children
        for child in curr_node.children:
            if predicted_classification == child.cwe_id:
                curr_node = child
                break
        else:
            return curr_node.cwe_id

total_samples = 10
accurate = test_accuracy(hierarchical_system_prompting, sample_df, total_samples)
print(f"Accuracy: {accurate}/{total_samples} = {accurate / total_samples:.2%}")

Accuracy: 2/10 = 20.00%
