# Dataset Balancing

In [None]:
from src.paths import PYTHON_VULNERABILITIES_WITHOUT_BALANCING, SYNTHETIC_DATA_PATH, FINAL_VULNERABILITIES_DATA_PATH
import polars as pl

In [60]:
# Load the pre-balancing vulnerability dataset from the parquet file referenced by
# PYTHON_VULNERABILITIES_WITHOUT_BALANCING; the bare name below renders it as the cell output.
vulnerabilities_without_balancing = pl.read_parquet(PYTHON_VULNERABILITIES_WITHOUT_BALANCING)
vulnerabilities_without_balancing

commit,repo,new_file,patch,code_unit_after_fix,vulnerability_id,cwe_id,old_file,code_unit_before_fix,clustered_cwe_id
str,str,str,str,str,str,list[str],str,str,list[str]
"""ab52630d0644e42a75eb88b78b9a9d…","""python-pillow/Pillow""","""src/libImaging/SgiRleDecode.c""","""@@ -157,6 +157,11 @@ ImagingSg…","""if (c->rleoffset + c->rlelengt…","""2019-16865""","[""CWE-770""]","""src/libImaging/SgiRleDecode.c""","""""","[""CWE-664""]"
"""fe507474f77084faef4539101e2bbb…","""hyperledger/indy-node""","""indy_node/utils/node_control_u…","""@@ -199,9 +199,27 @@ def run_s…","""class NodeControlUtil:  @cl…","""2022-31020""","[""CWE-20"", ""CWE-287""]","""indy_node/utils/node_control_u…","""""","[""CWE-707"", ""CWE-284""]"
"""8d76538d6e105947272b000581c6fa…","""apache/airflow""","""airflow/api_connexion/endpoint…","""@@ -17,25 +17,38 @@  from __fu…","""from typing import TYPE_CHECKI…","""2023-50944""","[""CWE-862""]","""airflow/api_connexion/endpoint…","""from airflow.api_connexion.exc…","[""CWE-284""]"
"""41bd3645bdb616e1248b2167ca8363…","""gradio-app/gradio""","""gradio/networking.py""","""@@ -377,15 +377,14 @@ def inte…","""def file(path):  if (  …","""2021-43831""","[""CWE-22""]","""gradio/networking.py""","""def file(path):  path = sec…","[""CWE-22""]"
"""424c68f15ad9f532d73e5afed33ff4…","""ansible/ansible""","""lib/ansible/plugins/connection…","""@@ -273,11 +273,14 @@ def wrap…","""def wrapped(self, *args, **kwa…","""2018-16876""","[""CWE-200""]","""lib/ansible/plugins/connection…","""def wrapped(self, *args, **kwa…","[""CWE-200""]"
…,…,…,…,…,…,…,…,…,…
"""5ed9478fdef96a06eeec9093f9e768…","""PaddlePaddle/Paddle""","""python/paddle/jit/dy2static/co…","""@@ -693,77 +693,6 @@ def has_n…","""""","""2023-52314""","[""CWE-78""]","""python/paddle/jit/dy2static/co…","""def convert_shape_compare(left…","[""CWE-707""]"
"""2c11575b1a3dd8b0df26a879ba856c…","""thinkst/opencanary""","""opencanary/modules/portscan.py""","""@@ -1,8 +1,8 @@  from opencana…","""from opencanary import safe_ex…","""2024-48911""","[""CWE-863""]","""opencanary/modules/portscan.py""","""import subprocess def detectNF…","[""CWE-284""]"
"""6cde16f3f4711583ae4d896dfcc125…","""NVIDIA/NVFlare""","""nvflare/app_common/abstract/le…","""@@ -13,8 +13,7 @@  # limitatio…","""from nvflare.fuel.utils import…","""2022-34668""","[""CWE-502""]","""nvflare/app_common/abstract/le…","""import pickle class Learnable(…","[""CWE-664""]"
"""73eb03bd75365e112b39877e26ef52…","""ubertidavide/fastbots""","""fastbots/exceptions.py""","""@@ -23,3 +23,17 @@ def __init_…","""class ExpectedUrlError(Generic…","""2023-48699""","[""CWE-94"", ""CWE-95""]","""fastbots/exceptions.py""","""""","[""CWE-707"", ""CWE-664""]"


In [62]:
# Count distinct vulnerabilities per clustered CWE and keep only the CWEs that
# have fewer than 200 of them, ordered from the most to the least frequent.
_cwe_vulnerability_pairs = (
    vulnerabilities_without_balancing
    .explode("clustered_cwe_id")
    .select("clustered_cwe_id", "vulnerability_id")
    .unique(["clustered_cwe_id", "vulnerability_id"])
)
# to_series() yields the first selected column, i.e. "clustered_cwe_id".
_vulnerabilities_per_cwe = _cwe_vulnerability_pairs.to_series().value_counts()
possible_to_sample_cwes = _vulnerabilities_per_cwe.sort("count", descending=True).filter(
    pl.col("count") < 200
)
possible_to_sample_cwes

clustered_cwe_id,count
str,u32
"""CWE-284""",172
"""CWE-200""",158
"""CWE-79""",153
"""CWE-22""",124
"""CWE-610""",109


In [74]:
# Character length of every code unit, before and after the fix.
_code_unit_lengths = vulnerabilities_without_balancing.select(
    "vulnerability_id",
    code_unit_after_fix_len=pl.col("code_unit_after_fix").str.len_chars(),
    code_unit_before_fix_len=pl.col("code_unit_before_fix").str.len_chars(),
)

# Sum the lengths of all code units that belong to the same vulnerability.
_lengths_per_vulnerability = _code_unit_lengths.group_by("vulnerability_id").agg(
    pl.col("code_unit_after_fix_len").sum(),
    pl.col("code_unit_before_fix_len").sum(),
)

# Keep the ids of vulnerabilities whose combined code is shorter than 20000 characters.
_combined_len = pl.col("code_unit_after_fix_len") + pl.col("code_unit_before_fix_len")
short_vulnerabilities = (
    _lengths_per_vulnerability
    .with_columns(total_len=_combined_len)
    .filter(pl.col("total_len") < 20000)
    .select("vulnerability_id")
    .unique("vulnerability_id")
)
short_vulnerabilities.describe()

statistic,vulnerability_id
str,str
"""count""","""1062"""
"""null_count""","""0"""
"""mean""",
"""std""",
"""min""","""2013-0270"""
"""25%""",
"""50%""",
"""75%""",
"""max""","""GHSA-x563-6hqv-26mr"""


In [None]:
from pathlib import Path
from uuid import uuid4
from pydantic import BaseModel
from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
import json
from collections import defaultdict

# Upper bound used by the generation loop below when topping up a CWE.
BALANCE_AMOUNT = 200

# Description of each CWE that gets synthetic samples; it is inserted into the LLM prompt.
CWE_DESCRIPTION = {
    "CWE-284": "The product does not restrict or incorrectly restricts access to a resource from an unauthorized actor.",
    "CWE-200": "The product exposes sensitive information to an actor that is not explicitly authorized to have access to that information.",
    "CWE-79": "The product does not neutralize or incorrectly neutralizes user-controllable input before it is placed in output that is used as a web page that is served to other users.",
    "CWE-22": "The product uses external input to construct a pathname that is intended to identify a file or directory that is located underneath a restricted parent directory, but the product does not properly neutralize special elements within the pathname that can cause the pathname to resolve to a location that is outside of the restricted directory.",
    "CWE-610": "The product uses an externally controlled name or reference that resolves to a resource that is outside of the intended control sphere.",
}

# Number of synthetic sample files already on disk, keyed by their "cwe_id" field.
# A defaultdict is used so CWEs without any file read as 0.
synthetic_cwes_count = defaultdict(int)
for sample_path in SYNTHETIC_DATA_PATH.glob("*.json"):
    stored_sample = json.loads(sample_path.read_text())
    synthetic_cwes_count[stored_sample["cwe_id"]] += 1

# Define Pydantic models for vulnerability examples.
class VulnerabilityCodePart(BaseModel):
    """A code snippet together with the file it belongs to."""
    # Path of the file that contains the snippet.
    file: Path
    # Source text of the snippet.
    code: str

class VulnerabilityCodePartExample(BaseModel):
    """One changed code unit of a vulnerability fix: its state before, its state after, and the patch."""
    # Code unit as it was before the fix (nullable).
    code_before_vulnerability_fix: VulnerabilityCodePart | None
    # Code unit as it is after the fix (nullable).
    code_after_vulnerability_fix: VulnerabilityCodePart | None
    # Diff text for this code unit.
    patch: str

class VulnerabilityFullCodeExample(BaseModel):
    """All changed code units that make up one vulnerability example."""
    # One entry per changed code unit of the vulnerability.
    vulnerability_code: list[VulnerabilityCodePartExample]

class VulnerabilityExamples(BaseModel):
    """Reference examples for one CWE, bundled with the CWE id and its description."""
    # Reference vulnerabilities that are serialized into the prompt.
    vulnerabilities_examples: list[VulnerabilityFullCodeExample]
    # CWE identifier, e.g. "CWE-22".
    cwe_id: str
    # Human-readable description of the CWE.
    cwe_description: str

def generate_new_vulnerability_samples(vuln_examples: VulnerabilityExamples) -> str:
    """
    Use an OpenAI chat model (via LangChain) to generate new vulnerability samples.

    The provided reference examples are serialized to JSON and embedded in the
    prompt as guidance so that the generated code exhibits a vulnerability of
    the same CWE type. The prompt asks the model for a JSON object with a
    "vulnerabilities" list whose items contain:
      - 'code_with_vulnerability': Code that intentionally contains the vulnerability. With subfields:
        - 'code': The code snippet.
        - 'file': The file path where the code is located.
      - 'code_after_fix': The corrected version of the code. With subfields:
        - 'code': The code snippet.
        - 'file': The file path where the code is located.
      - 'patch': A git diff patch showing the changes.

    Args:
        vuln_examples: Reference examples plus the CWE id and description.

    Returns:
        The raw text produced by the model. It is NOT parsed or validated here;
        because the prompt shows the output format inside a Markdown ```json
        fence, the reply may be wrapped in such a fence and callers have to
        strip it before parsing.
    """
    # Convert the vulnerability examples to a JSON string for inclusion in the prompt.
    vulnerability_examples_str = vuln_examples.model_dump_json(indent=2)

    # Skeleton of the expected reply. It is passed as a template variable (not
    # pasted into the template) so its literal braces are not treated as placeholders.
    output_json = """{
    "vulnerabilities": [
        {
            "code_with_vulnerability": {
                "file": "data_processor.py",
                "code": "..."
            },
            "code_after_fix": {
                "file": "data_processor.py",
                "code": "..."
            },
            "patch": "@@..."
        },
        {
            "code_with_vulnerability": {
                "file": "web_server.py",
                "code": "..."
            },
            "code_after_fix": {
                "file": "web_server.py",
                "code": "..."
            },
            "patch": "@@ ..."
        }
    ]
}"""

    # Define the prompt template that incorporates the reference vulnerability examples.
    # Every placeholder used in the template is declared, including "output_json";
    # a missing entry fails template validation on LangChain versions that validate.
    prompt_template = PromptTemplate(
        input_variables=["vulnerability_examples", "cwe_id", "cwe_description", "output_json"],
        template="""
You are a security expert and software engineer. You have been provided with reference examples of vulnerabilities and their fixes in the following JSON format:
{vulnerability_examples}

These examples illustrate a vulnerability of type {cwe_id}: {cwe_description}.
Using these reference examples as guidance, generate NEW vulnerability samples that intentionally include a vulnerability of the same CWE type ({cwe_id}). 
GENERATE ONLY IN PYTHON LANGUAGE.

### Instructions:

1. **GENERATE A NEW VULNERABILITY SAMPLE**  
   - Write a Python script(s) (`code_with_vulnerability`) that intentionally contains a vulnerability of type **{cwe_id}**.  
   - The nature of the vulnerability must be **identical** to the reference examples.

2. **FIX THE VULNERABILITY**  
   - Provide a corrected version (`code_after_fix`) of the Python code where the vulnerability has been **fully mitigated**.

3. **GENERATE A PATCH**  
   - Create a valid **git diff** patch that shows the exact changes between the vulnerable and fixed code.

Return the result as a JSON object with a "vulnerabilities" field which is a list of objects, each containing the keys "code_with_vulnerability", "code_after_fix", and "patch". 
Make sure that the generated samples are new, distinct from the reference examples, and correctly demonstrate the vulnerability of type {cwe_id}.
Do not generate any comments in source python code. Ensure that the PYTHON code is syntactically correct and follows the conventions of the programming language.

### Output Format:
```json
{output_json}
```

Make sure that output consists of valid JSON format. NO ADDITIONAL TEXT IN OUTPUT.
"""
    )

    # Initialize the LLM with a temperature of 0 for deterministic output.
    llm = ChatOpenAI(temperature=0, model="gpt-4o",)

    # Create an LLMChain using the prompt.
    chain = LLMChain(llm=llm, prompt=prompt_template)

    # Execute the chain with the provided vulnerability examples and related info.
    result = chain.run(
        vulnerability_examples=vulnerability_examples_str,
        cwe_id=vuln_examples.cwe_id,
        cwe_description=vuln_examples.cwe_description,
        output_json=output_json
    )
    return result

        
def _strip_markdown_code_fence(text: str) -> str:
    """Return ``text`` without a surrounding Markdown code fence, if it has one."""
    stripped = text.strip()
    if not stripped.startswith("```"):
        return stripped
    body = stripped.removeprefix("```").removesuffix("```")
    # Drop the optional language tag ("json") that follows the opening fence.
    return body.lstrip().removeprefix("json").strip()


def sample_cwe_type(cwe_id: str, n_examples: int = 3) -> str:
    """
    Generate new synthetic vulnerability samples for ``cwe_id``.

    Random reference vulnerabilities that carry ``cwe_id`` are picked from the
    short vulnerabilities of the real dataset, converted into the pydantic
    prompt models and handed to ``generate_new_vulnerability_samples``.

    Args:
        cwe_id: Clustered CWE id to sample for; must be a key of ``CWE_DESCRIPTION``.
        n_examples: Maximum number of reference vulnerabilities to show the
            model. Fewer are used when fewer candidates exist.

    Returns:
        The model reply with any surrounding Markdown code fence removed. The
        text is expected to be JSON but is not parsed or validated here.

    Raises:
        ValueError: If no reference vulnerability exists for ``cwe_id``.
        KeyError: If ``cwe_id`` has no entry in ``CWE_DESCRIPTION``.
    """
    candidates = (
        vulnerabilities_without_balancing
        .join(short_vulnerabilities, on="vulnerability_id", how="inner")
        .group_by("vulnerability_id")
        .agg("new_file", "old_file", "code_unit_after_fix", "code_unit_before_fix", "clustered_cwe_id", "patch")
        # After the aggregation "clustered_cwe_id" holds one CWE list per file row;
        # exploding it gives one row per file, each with a plain list of CWE ids.
        .explode("clustered_cwe_id")
        .filter(pl.col("clustered_cwe_id").list.contains(cwe_id))
        # The explode repeats multi-file vulnerabilities; keep each one once so the
        # same vulnerability cannot be drawn several times as a reference.
        .unique(subset="vulnerability_id")
    )
    if candidates.is_empty():
        raise ValueError(f"No reference vulnerabilities available for {cwe_id}")

    # Never ask for more rows than exist; sampling without replacement would raise.
    reference_rows = candidates.sample(min(n_examples, candidates.height), shuffle=True).to_dicts()

    reference_examples = []
    for row in reference_rows:
        # The aggregated columns are parallel lists with one entry per changed file.
        code_parts = [
            VulnerabilityCodePartExample(
                code_before_vulnerability_fix=VulnerabilityCodePart(file=Path(old_file), code=code_before),
                code_after_vulnerability_fix=VulnerabilityCodePart(file=Path(new_file), code=code_after),
                patch=patch,
            )
            for old_file, code_before, new_file, code_after, patch in zip(
                row["old_file"],
                row["code_unit_before_fix"],
                row["new_file"],
                row["code_unit_after_fix"],
                row["patch"],
            )
        ]
        reference_examples.append(VulnerabilityFullCodeExample(vulnerability_code=code_parts))

    example_vulnerabilities = VulnerabilityExamples(
        vulnerabilities_examples=reference_examples,
        cwe_id=cwe_id,
        cwe_description=CWE_DESCRIPTION[cwe_id],
    )

    # Generate new vulnerability samples that demonstrate the same CWE type.
    new_samples = generate_new_vulnerability_samples(example_vulnerabilities)
    return _strip_markdown_code_fence(new_samples)

SYNTHETIC_DATA_PATH.mkdir(parents=True, exist_ok=True)

# Top up every under-represented CWE with synthetic samples, one JSON file per LLM call.
for cwe_to_sample in possible_to_sample_cwes.to_dicts():
    cwe_id = cwe_to_sample["clustered_cwe_id"]
    # Real vulnerabilities plus the synthetic files already counted on disk.
    already_available = cwe_to_sample["count"] + synthetic_cwes_count[cwe_id]
    # The range stops at BALANCE_AMOUNT itself so the total never exceeds the target.
    for i in range(already_available, BALANCE_AMOUNT):
        print(f"Sampling {cwe_id} {i}")
        try:
            data = json.loads(sample_cwe_type(cwe_id))
        except json.JSONDecodeError as exc:
            # One malformed reply should not abort the whole run; re-running the
            # cell recounts the files on disk and fills the remaining gap.
            print(f"Skipping {cwe_id} {i}: LLM output is not valid JSON ({exc})")
            continue
        data["cwe_id"] = cwe_id
        with open(SYNTHETIC_DATA_PATH / f"{uuid4().hex}.json", "w") as f:
            json.dump(data, f)
        


Sampling CWE-22 138
Sampling CWE-22 139
Sampling CWE-22 140
Sampling CWE-22 141
Sampling CWE-22 142
Sampling CWE-22 143
Sampling CWE-22 144
Sampling CWE-22 145
Sampling CWE-22 146
Sampling CWE-22 147
Sampling CWE-22 148
Sampling CWE-22 149
Sampling CWE-22 150
Sampling CWE-22 151
Sampling CWE-22 152
Sampling CWE-22 153
Sampling CWE-22 154
Sampling CWE-22 155
Sampling CWE-22 156
Sampling CWE-22 157
Sampling CWE-22 158
Sampling CWE-22 159
Sampling CWE-22 160
Sampling CWE-22 161
Sampling CWE-22 162
Sampling CWE-22 163
Sampling CWE-22 164
Sampling CWE-22 165
Sampling CWE-22 166
Sampling CWE-22 167
Sampling CWE-22 168
Sampling CWE-22 169
Sampling CWE-22 170
Sampling CWE-22 171
Sampling CWE-22 172


In [54]:
# Load every generated JSON file and reshape it into the column layout of the real
# dataset. Files that do not have the expected structure are reported and deleted.
synthetic_samples = []
for path in SYNTHETIC_DATA_PATH.glob("*.json"):
    with open(path) as f:
        data = json.load(f)
    try:
        # The file name (a uuid hex) doubles as the id of the synthetic vulnerability.
        data["vulnerability_id"] = path.stem
        generated = data["vulnerabilities"]
        data["code_unit_before_fix"] = [x["code_with_vulnerability"]["code"] for x in generated]
        data["code_unit_after_fix"] = [x["code_after_fix"]["code"] for x in generated]
        data["new_file"] = [x["code_after_fix"]["file"] for x in generated]
        data["old_file"] = [x["code_with_vulnerability"]["file"] for x in generated]
        data["patch"] = [x["patch"] for x in generated]
        # Wrapped in a list to match the list-typed CWE columns of the real data.
        data["cwe_id"] = [data["cwe_id"]]
        del data["vulnerabilities"]
        # Test the suffix of every file name; an empty name must count as a non-Python file too.
        if any(not file.endswith(".py") for file in data["new_file"]):
            raise ValueError("Not all files are python files")
    except (KeyError, TypeError, AttributeError, ValueError) as exc:
        # These cover missing keys, wrong container types and non-string file names.
        print(f"Removing invalid sample {path.name}: {exc!r}")
        path.unlink()
    else:
        synthetic_samples.append(data)


# One row per generated code unit; the exploded columns are parallel lists of equal length.
synthetic_samples = pl.DataFrame(synthetic_samples).explode("code_unit_before_fix", "code_unit_after_fix", "new_file", "old_file", "patch")
synthetic_samples.describe()

statistic,cwe_id,vulnerability_id,code_unit_before_fix,code_unit_after_fix,new_file,old_file,patch,cwe_description
str,f64,str,str,str,str,str,str,str
"""count""",537.0,"""537""","""537""","""537""","""537""","""537""","""537""","""39"""
"""null_count""",0.0,"""0""","""0""","""0""","""0""","""0""","""0""","""498"""
"""mean""",,,,,,,,
"""std""",,,,,,,,
"""min""",,"""002ce9a6800d4da5955a8ca6a9a146…","""SESSION_COOKIE_SECURE = False …","""MAX_BODY_SIZE = 1048576 # 1MB…","""action_handler.py""",""" webapp/processing.py""","""@@ -1,10 +1,12 @@  def fetch_u…","""The code does not sufficiently…"
"""25%""",,,,,,,,
"""50%""",,,,,,,,
"""75%""",,,,,,,,
"""max""",,"""fd64e7e56d474fc1bf466b043f50ce…","""import threading from PIL impo…","""import time import random def …","""webapp/utils/file_reader.py""","""webapp/utils/file_reader.py""","""@@ -9,7 +9,7 @@  username …","""The product does not use or in…"


In [56]:
# Give the synthetic rows the columns they lack compared to the real data:
# a synthetic marker, a clustered CWE (copied from "cwe_id") and empty commit/repo.
_synthetic_samples = synthetic_samples.with_columns(
    is_synthetic=pl.lit(True),
    clustered_cwe_id=pl.col("cwe_id"),
    commit=pl.lit(None),
    repo=pl.lit(None),
).drop("cwe_description")

# Project the real rows onto the same columns, in the same order, then stack both frames.
_real_samples = vulnerabilities_without_balancing.with_columns(is_synthetic=pl.lit(False))
_shared_columns = _synthetic_samples.columns
final_data = pl.concat([_real_samples.select(_shared_columns), _synthetic_samples])
final_data.describe()

statistic,cwe_id,vulnerability_id,code_unit_before_fix,code_unit_after_fix,new_file,old_file,patch,is_synthetic,clustered_cwe_id,commit,repo
str,f64,str,str,str,str,str,str,f64,f64,str,str
"""count""",3398.0,"""3398""","""3398""","""3398""","""3398""","""3398""","""3398""",3398.0,3398.0,"""2861""","""2861"""
"""null_count""",0.0,"""0""","""0""","""0""","""0""","""0""","""0""",0.0,0.0,"""537""","""537"""
"""mean""",,,,,,,,0.158034,,,
"""std""",,,,,,,,,,,
"""min""",,"""002ce9a6800d4da5955a8ca6a9a146…","""""","""""","""Adyen/util.py""",""" webapp/processing.py""","""@@ -1 +1 @@ -__version__ = ""1.…",0.0,,"""001b0634cd309e372edb6d7d95d083…","""389ds/389-ds-base"""
"""25%""",,,,,,,,,,,
"""50%""",,,,,,,,,,,
"""75%""",,,,,,,,,,,
"""max""",,"""fd64e7e56d474fc1bf466b043f50ce…","""},  ""scope"": {  …","""} else {  actionsTempla…","""zproject/urls.py""","""zproject/urls.py""","""@@ -99,7 +99,9 @@ extern ""C"" {…",1.0,,"""ffc095a3e5acc1c404773a0510e6d0…","""zwczou/weixin-python"""


In [59]:
# Number of distinct vulnerabilities per clustered CWE in the combined dataset.
_final_cwe_vulnerability_pairs = (
    final_data
    .select("vulnerability_id", "clustered_cwe_id")
    .explode("clustered_cwe_id")
    .unique(["vulnerability_id", "clustered_cwe_id"])
)
_final_cwe_vulnerability_pairs.group_by("clustered_cwe_id").agg(pl.count("vulnerability_id"))

clustered_cwe_id,vulnerability_id
str,u32
"""CWE-400""",88
"""CWE-691""",141
"""CWE-22""",109
"""CWE-664""",282
"""CWE-610""",108
…,…
"""CWE-79""",153
"""CWE-703""",157
"""CWE-284""",172
"""CWE-200""",136


In [79]:
# DATA_PATH is not imported anywhere in this notebook, while FINAL_VULNERABILITIES_DATA_PATH
# is imported in the first cell and otherwise unused, so the imported constant is used here.
# NOTE(review): assumes FINAL_VULNERABILITIES_DATA_PATH is the intended
# "final_vulnerabilities.parquet" location — confirm against src.paths.
final_data.write_parquet(FINAL_VULNERABILITIES_DATA_PATH)