In [None]:
"""
BRON Dataset to Neo4j Loader

This script downloads the BRON (Linked Threat Data) dataset, transforms it into
a Neo4j-compatible format, and uploads it to a Neo4j database.

BRON links threat intelligence data including:
- CVE (Common Vulnerabilities and Exposures)
- CWE (Common Weakness Enumeration)
- CAPEC (Common Attack Pattern Enumeration and Classification)
- ATT&CK Tactics and Techniques

Source: https://github.com/ALFA-group/BRON
"""

import json
import os
from collections import Counter
from pathlib import Path

import requests
from neo4j import GraphDatabase

In [None]:
# Configuration: where the dataset is fetched from and where it is cached
BRON_URL = "https://github.com/ALFA-group/BRON/raw/refs/heads/master/graph_db/example_data/BRON.json"
DATA_DIR = Path("data")
BRON_FILE = DATA_DIR.joinpath("BRON.json")

# Neo4j connection settings, each overridable through an environment variable
NEO4J_URI = os.getenv("NEO4J_URI", "bolt://localhost:7687")
NEO4J_USER = os.getenv("NEO4J_USER", "neo4j")
NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD", "")
NEO4J_DATABASE = os.getenv("NEO4J_DATABASE", "neo4j")

# There is no usable default password, so stop early when none was provided
if not NEO4J_PASSWORD:
    _missing_password_message = (
        "NEO4J_PASSWORD environment variable is required. "
        "Set it before running this script."
    )
    raise ValueError(_missing_password_message)

In [None]:
def download_bron_dataset(url: str, output_path: Path) -> dict:
    """
    Download the BRON dataset from GitHub if not already cached locally.

    The payload is parsed from raw bytes, so decoding follows the JSON
    encoding rules instead of the platform's default text encoding. A freshly
    downloaded payload is parsed *before* it is written to disk, so a response
    that is not valid JSON is never left behind as a poisoned cache file.

    Args:
        url: URL to the BRON.json file
        output_path: Local path to save the downloaded file

    Returns:
        Parsed JSON data as a dictionary

    Raises:
        requests.HTTPError: If the server answers with an error status.
        json.JSONDecodeError: If the cached or downloaded payload is not
            valid JSON.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)

    if output_path.exists():
        print(f"Using cached dataset: {output_path}")
        return json.loads(output_path.read_bytes())

    print(f"Downloading BRON dataset from {url}...")
    response = requests.get(url, timeout=60)
    response.raise_for_status()
    payload = response.content

    # Validate first: only a payload that parses is worth caching.
    data = json.loads(payload)
    output_path.write_bytes(payload)
    print(f"Saved to {output_path}")
    return data

# Fetch the dataset (or reuse the cached copy) and pull out the raw graph parts
data = download_bron_dataset(BRON_URL, BRON_FILE)
raw_nodes = data.get("nodes", [])
# Edge records are read from "edges" when present, otherwise from "links"
raw_edges = data["edges"] if "edges" in data else data.get("links", [])

print(f"Loaded {len(raw_nodes)} nodes and {len(raw_edges)} edges")

In [None]:
def _neo4j_value(v):
    """
    Convert Python values into Neo4j-compatible property values.

    Neo4j accepts primitives and lists of primitives. Anything else is
    JSON-encoded as a string. Inside a list, every non-primitive element
    (dicts and nested lists alike) is encoded individually, so the result is
    always a flat list and never a list of lists.

    Note: list homogeneity is not enforced here; a list mixing element types
    is returned with its elements unchanged.

    Args:
        v: Any JSON-derived value.

    Returns:
        ``v`` itself for primitives and ``None``, a flat list for lists, and
        a JSON string for everything else.
    """
    scalar_types = (str, int, float, bool)

    if v is None or isinstance(v, scalar_types):
        return v
    if isinstance(v, list):
        # Nested containers cannot be stored inside a property list, so each
        # one is collapsed to its JSON text.
        return [
            x if x is None or isinstance(x, scalar_types) else json.dumps(x)
            for x in v
        ]
    return json.dumps(v)


def _flatten_props(obj: dict, parent_key: str = "", sep: str = "_") -> dict:
    """
    Collapse nested dicts/lists into a one-level dict with joined keys.

    Nested dict keys are joined with ``sep``; list positions become numeric
    key segments. A list made only of scalars (or ``None``) is kept whole.
    Inside a non-scalar list, a nested list of scalars is kept as a list
    under its indexed key, while any deeper list is stored as JSON text.
    A non-dict ``obj`` yields an empty dict.

    Examples:
        {'a': {'b': 1}} -> {'a_b': 1}
        {'list': [{'x':1}, {'x':2}]} -> {'list_0_x':1, 'list_1_x':2}
        {'vals': [1,2]} -> {'vals': [1,2]}  (kept as list of primitives)
    """
    if not isinstance(obj, dict):
        return {}

    def scalars_only(seq) -> bool:
        # True when every element is None or a plain scalar.
        return all(x is None or isinstance(x, (str, int, float, bool)) for x in seq)

    flat = {}
    for key, value in obj.items():
        path = f"{parent_key}{sep}{key}" if parent_key else key

        if isinstance(value, dict):
            flat.update(_flatten_props(value, path, sep=sep))
            continue

        # Scalars and scalar-only lists are stored under the key as they are.
        if not isinstance(value, list) or scalars_only(value):
            flat[path] = value
            continue

        # Mixed list: give each element its own indexed key.
        for position, member in enumerate(value):
            member_path = f"{path}{sep}{position}"
            if isinstance(member, dict):
                flat.update(_flatten_props(member, member_path, sep=sep))
            elif isinstance(member, list) and not scalars_only(member):
                flat[member_path] = json.dumps(member)
            else:
                flat[member_path] = member

    return flat


def sanitize_props(props: dict) -> dict:
    """Return a flat copy of *props* whose values are Neo4j-safe.

    A falsy *props* (``None`` or an empty dict) is treated as no properties.
    """
    source = props if props else {}
    cleaned = {}
    for key, value in _flatten_props(source).items():
        cleaned[key] = _neo4j_value(value)
    return cleaned

In [None]:
def parse_node(idx: int, n) -> dict:
    """
    Parse a BRON node entry into Neo4j-ready format.

    Two input shapes are understood: a list ``[key, meta_dict, ...]`` and a
    dict carrying its own ``id``. Any other shape produces a node with the
    fallback id and only the ``name``/``original_id`` attributes.

    Args:
        idx: Index of the node (used as fallback ID)
        n: Raw node data from BRON (list or dict format)

    Returns:
        Dict with keys: id, label, attributes. ``label`` is taken from the
        ``datatype`` field (or ``type``) and is ``None`` when neither is set.
    """
    bron_key = f"node-{idx}"
    meta = {}

    if isinstance(n, list) and len(n) >= 2:
        bron_key = str(n[0])
        if isinstance(n[1], dict):
            meta = n[1]
    elif isinstance(n, dict):
        bron_key = str(n.get("id", bron_key))
        meta = n

    original_id = meta.get("original_id", bron_key)
    datatype = meta.get("datatype") or meta.get("type")
    name = meta.get("name") or original_id

    # An explicit JSON null ("metadata": null) must behave like a missing
    # field; dict(None) would raise TypeError.
    attributes = dict(meta.get("metadata") or {})
    # Set last so these two always win over same-named metadata entries.
    attributes["name"] = name
    attributes["original_id"] = original_id

    return {
        "id": bron_key,
        "label": datatype,
        "attributes": sanitize_props(attributes),
    }


# Relationship type mapping based on source/target node types.
# Keys are (source type, target type) pairs; parse_edge derives each type from
# the text before the first "_" in the node key and falls back to
# "RELATES_TO" for any pair that is not listed here.
EDGE_TYPE_MAP = {
    ("capec", "cwe"): "EXPLOITS_WEAKNESS",
    ("cwe", "cve"): "MANIFESTS_AS",
    ("tactic", "technique"): "INCLUDES_TECHNIQUE",
    ("technique", "capec"): "IMPLEMENTS_PATTERN",
}


def parse_edge(edge) -> dict | None:
    """
    Parse a BRON edge into Neo4j-ready format.

    Endpoint keys are converted to ``str`` so they compare equal to the node
    ids produced by ``parse_node`` (which are always strings) and so the
    type prefix can be extracted even when a key is not a string.

    Args:
        edge: Raw edge data from BRON (list format: [source, target, props?])

    Returns:
        Dict with keys: source, target, type, attributes (or None if invalid,
        i.e. not a list, fewer than two items, or a missing endpoint)
    """
    if not isinstance(edge, list) or len(edge) < 2:
        return None
    if edge[0] is None or edge[1] is None:
        # A null endpoint can never match a node.
        return None

    source_key = str(edge[0])
    target_key = str(edge[1])
    edge_props = edge[2] if len(edge) > 2 and isinstance(edge[2], dict) else {}

    # The node type is the part of the key before the first underscore.
    source_type = source_key.partition("_")[0]
    target_type = target_key.partition("_")[0]
    edge_type = EDGE_TYPE_MAP.get((source_type, target_type), "RELATES_TO")

    return {
        "source": source_key,
        "target": target_key,
        "type": edge_type,
        "attributes": sanitize_props(edge_props),
    }

In [None]:
# Transform all nodes and edges; edges that cannot be parsed are dropped
nodes = [parse_node(idx, n) for idx, n in enumerate(raw_nodes)]
edges = [e for e in map(parse_edge, raw_edges) if e is not None]

print(f"Parsed {len(nodes)} nodes and {len(edges)} edges")

# Show node type distribution
label_counts = Counter(node["label"] for node in nodes)

print("\nNode types:")
# Sort on the string form of the label: a node without a datatype has the
# label None, and None cannot be ordered against str labels.
for label, count in sorted(label_counts.items(), key=lambda item: str(item[0])):
    print(f"  {label}: {count}")

In [None]:
def batch_list(lst: list, batch_size: int = 1000):
    """Yield consecutive slices of *lst*, each holding at most *batch_size* items."""
    offsets = range(0, len(lst), batch_size)
    yield from (lst[start:start + batch_size] for start in offsets)


def load_nodes(tx, nodes_batch: list):
    """
    Merge one batch of parsed nodes into Neo4j within the given transaction.

    Requires the APOC plugin (the query calls ``apoc.merge.node``).

    Args:
        tx: Object exposing ``run(query, **params)``; ``load_to_neo4j`` passes
            this function to ``session.execute_write``, which supplies it.
        nodes_batch: List of dicts providing the ``id``, ``label`` and
            ``attributes`` fields that the query reads.

    Returns:
        None. The result of ``tx.run`` is neither consumed nor returned.
    """
    # Each node is merged on its single label plus its ``id`` property.
    # ``node.attributes`` is the third procedure argument; per the APOC docs
    # that slot holds properties applied on creation only -- TODO confirm
    # against the installed APOC version.
    query = """
    UNWIND $nodes AS node
    CALL apoc.merge.node([node.label], {id: node.id}, node.attributes) YIELD node AS n
    RETURN count(n)
    """
    tx.run(query, nodes=nodes_batch)


def load_edges(tx, edges_batch: list):
    """
    Merge one batch of parsed edges into Neo4j within the given transaction.

    Requires the APOC plugin (the query calls ``apoc.merge.relationship``).

    Args:
        tx: Object exposing ``run(query, **params)``; ``load_to_neo4j`` passes
            this function to ``session.execute_write``, which supplies it.
        edges_batch: List of dicts providing the ``source``, ``target``,
            ``type`` and ``attributes`` fields that the query reads.

    Returns:
        None. The result of ``tx.run`` is neither consumed nor returned.
    """
    # Endpoints are looked up by ``id`` alone, with no label in the pattern.
    # NOTE(review): an edge whose source or target id matches no node yields
    # no row from MATCH and is therefore skipped without an error -- confirm
    # this is the intended handling of dangling edges.
    # NOTE(review): a label-less MATCH presumably cannot use a label-scoped
    # index on ``id``; verify load performance on the full dataset.
    query = """
    UNWIND $edges AS edge
    MATCH (source {id: edge.source})
    MATCH (target {id: edge.target})
    CALL apoc.merge.relationship(source, edge.type, {}, edge.attributes, target) YIELD rel
    RETURN count(rel)
    """
    tx.run(query, edges=edges_batch)


def _load_in_batches(session, loader, items: list, noun: str, batch_size: int) -> None:
    """Write *items* through *loader*, one write transaction per batch, printing progress."""
    print(f"Loading {noun}...")
    loaded = 0
    for batch_number, batch in enumerate(batch_list(items, batch_size), start=1):
        session.execute_write(loader, batch)
        # Count what was really sent: the final batch may be shorter than
        # batch_size, so batch_number * batch_size could overstate progress.
        loaded += len(batch)
        if batch_number % 10 == 0:
            print(f"  Loaded {loaded} {noun}...")
    print(f"  Completed: {len(items)} {noun}")


def load_to_neo4j(nodes: list, edges: list, batch_size: int = 1000):
    """
    Load nodes and edges into Neo4j database.

    Nodes are written first so that the edge loader can find its endpoints.
    Each batch runs in its own write transaction, and progress is printed
    after every tenth batch.

    Requires APOC plugin to be installed on the Neo4j server.

    Args:
        nodes: Parsed node dicts, as accepted by ``load_nodes``.
        edges: Parsed edge dicts, as accepted by ``load_edges``.
        batch_size: Maximum number of items sent per transaction.
    """
    driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))

    try:
        with driver.session(database=NEO4J_DATABASE) as session:
            _load_in_batches(session, load_nodes, nodes, "nodes", batch_size)
            _load_in_batches(session, load_edges, edges, "edges", batch_size)
    finally:
        # Always release the connection pool, even when a batch fails.
        driver.close()

    print("Done!")


# Load data into Neo4j: push the parsed nodes and edges using the default
# batch size of load_to_neo4j (1000 items per transaction).
load_to_neo4j(nodes, edges)