# AI-Powered Vulnerability Assessment

This notebook performs AI-powered vulnerability assessment on container scanner results (Trivy + Grype).

## Step 1: Clone Repository & Install Dependencies

In [None]:
import os
import sys

# Clone the repository if not already present
if not os.path.exists('/content/LLM-Assisted-Container-Security-Analysis'):
    !git clone https://github.com/satyam-thakur/LLM-Assisted-Container-Security-Analysis.git
    print('✓ Repository cloned successfully')
else:
    print('✓ Repository already exists')

# Install required packages (the quotes stop the shell from treating ">=" as an output redirect)
!pip install -q "dspy-ai>=2.6.20" pandas ujson python-dotenv pydantic rich requests google-generativeai
print('✓ Dependencies installed')

# Set up paths
repo_root = '/content/LLM-Assisted-Container-Security-Analysis'
os.chdir(repo_root)
if repo_root not in sys.path:
    sys.path.insert(0, repo_root)
print(f'✓ Working directory: {os.getcwd()}')

## Step 2: Configure API Key

In [None]:
from getpass import getpass

# Prompt for API key securely
if 'GEMINI_API_KEY' not in os.environ:
    api_key = getpass('Enter your GEMINI_API_KEY: ')
    os.environ['GEMINI_API_KEY'] = api_key
    print('✓ API key configured')
else:
    print('✓ API key already set')

# Set defaults
os.environ.setdefault('GEMINI_MODEL', 'gemini-1.5-flash')
os.environ.setdefault('LM_TEMPERATURE', '0.2')
print(f'Model: {os.environ["GEMINI_MODEL"]}')

## Step 3: Configuration Module

Handles environment variables and DSPy configuration.

In [None]:
from typing import Optional, Tuple

def get_env(key: str, default: Optional[str] = None) -> Optional[str]:
    return os.environ.get(key, default)

def get_gemini_settings() -> Tuple[Optional[str], str]:
    api_key = get_env("GEMINI_API_KEY") or get_env("GOOGLE_API_KEY")
    model = get_env("GEMINI_MODEL", "gemini-1.5-flash")
    return api_key, model

def configure_dspy() -> str:
    """
    Try to configure DSPy to use Gemini. Returns:
    - "dspy": DSPy configured successfully
    - "fallback": Use direct Gemini API
    """
    api_key, model = get_gemini_settings()
    try:
        import dspy
        lm_name_candidates = [f"google/{model}", f"gemini/{model}", model]
        configured = False
        last_err = None
        for lm_name in lm_name_candidates:
            try:
                # temperature is an LM setting, so pass it to dspy.LM rather than dspy.configure
                lm = dspy.LM(model=lm_name, api_key=api_key,
                             temperature=float(get_env("LM_TEMPERATURE", "0.2")))
                dspy.configure(lm=lm)
                configured = True
                break
            except Exception as e:
                last_err = e
                continue
        if not configured:
            if os.getenv("DEBUG_DSPY_SETUP") == "1":
                print(f"[WARN] DSPy Gemini LM configuration failed: {last_err}")
            return "fallback"
        return "dspy"
    except Exception as e:
        if os.getenv("DEBUG_DSPY_SETUP") == "1":
            print(f"[WARN] DSPy not available or config failed: {e}")
        return "fallback"

print('✓ Configuration module loaded')
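
The lookup order in `get_gemini_settings` can be checked without touching the real environment. The sketch below re-implements the same precedence against a plain dictionary. `resolve_settings` and the sample keys are illustrative names, not part of the notebook.

```python
from typing import Mapping, Optional, Tuple

DEFAULT_MODEL = "gemini-1.5-flash"  # same default the notebook uses

def resolve_settings(env: Mapping[str, str]) -> Tuple[Optional[str], str]:
    """Mirror get_gemini_settings(): GEMINI_API_KEY wins, GOOGLE_API_KEY is the fallback."""
    api_key = env.get("GEMINI_API_KEY") or env.get("GOOGLE_API_KEY")
    model = env.get("GEMINI_MODEL", DEFAULT_MODEL)
    return api_key, model

# Hypothetical environments, used only to show the precedence.
print(resolve_settings({"GEMINI_API_KEY": "key-a", "GOOGLE_API_KEY": "key-b"}))
print(resolve_settings({"GOOGLE_API_KEY": "key-b", "GEMINI_MODEL": "custom-model"}))
print(resolve_settings({}))
```

An empty `GEMINI_API_KEY` is falsy, so it also falls through to `GOOGLE_API_KEY`.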

## Step 4: Prompts & Labels Module

Defines VEX labels and LLM instructions.

In [None]:
from typing import List

LABEL_CHOICES: List[str] = [
    "vulnerable",         # exploitable as-deployed
    "code_not_present",   # package/code not present in image
    "code_not_reachable", # present but not reachable/exposed
    "mitigated",          # present, but mitigations block exploit
    "fixed",              # fixed version present
    "false_positive",     # scanner likely wrong
]

SYSTEM_INSTRUCTIONS = f"""
You are a security expert performing container vulnerability validation.
Given a vulnerability record from scanners (e.g., Trivy/Grype) for a specific container image, decide if the vulnerability is exploitable in the current context.
Return ONLY a strict JSON object with the following keys: affected (boolean), label (one of {LABEL_CHOICES}), reason (<=120 words), risk (low|medium|high), remediation (<=120 words).
Be precise, reduce speculation, and ground your decision in the provided details.
"""

# DSPy Signature definition (optional)
VEXSignature = None
try:
    import dspy
    class VEXSignature(dspy.Signature):
        """Given vulnerability details from scanners, decide exploitability and produce VEX classification.
        Return ONLY strict JSON for: affected, label, reason, risk, remediation.
        """
        cve_id = dspy.InputField(desc="CVE or Vulnerability ID")
        package_name = dspy.InputField()
        installed_version = dspy.InputField()
        fixed_version = dspy.InputField()
        severity = dspy.InputField()
        title = dspy.InputField()
        description = dspy.InputField()
        image = dspy.InputField()
        scanner = dspy.InputField()
        output_json = dspy.OutputField(desc="Strict JSON with keys: affected, label, reason, risk, remediation")
except Exception:
    # DSPy is unavailable: VEXSignature stays None and the direct Gemini path is used
    pass

print('✓ Prompts module loaded')
print(f'Labels: {LABEL_CHOICES}')
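
Because the prompt asks for strict JSON with a fixed set of keys, a parsed reply can be checked before it is trusted. This is a minimal sketch, assuming the reply has already been parsed into a dictionary. `validate_reply` is a hypothetical helper that the notebook does not define, and the label list is copied from above so the block runs on its own.

```python
from typing import Any, Dict, List

LABELS = ["vulnerable", "code_not_present", "code_not_reachable",
          "mitigated", "fixed", "false_positive"]
RISKS = ("low", "medium", "high")
REQUIRED_KEYS = ("affected", "label", "reason", "risk", "remediation")

def validate_reply(reply: Dict[str, Any]) -> List[str]:
    """Return a list of problems; an empty list means the reply matches the contract."""
    problems = [f"missing key: {k}" for k in REQUIRED_KEYS if k not in reply]
    if "affected" in reply and not isinstance(reply["affected"], bool):
        # bool("false") is True in Python, so strings must not be accepted here.
        problems.append("affected must be a JSON boolean")
    if "label" in reply and reply["label"] not in LABELS:
        problems.append(f"unknown label: {reply['label']!r}")
    if "risk" in reply and reply["risk"] not in RISKS:
        problems.append(f"unknown risk: {reply['risk']!r}")
    for key in ("reason", "remediation"):
        if len(str(reply.get(key, "")).split()) > 120:
            problems.append(f"{key} exceeds 120 words")
    return problems

# Invented replies for illustration only.
good = {"affected": False, "label": "fixed", "reason": "Patched version installed.",
        "risk": "low", "remediation": "None required."}
bad = {"affected": "false", "label": "safe", "risk": "low"}
print(validate_reply(good))
print(validate_reply(bad))
```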

## Step 5: Scanner Loader Module

Loads and normalizes scanner JSON output.

In [None]:
import json
from typing import Any, Dict, List, Optional, Tuple
import pandas as pd

def load_scanner_results(path: str, limit: Optional[int] = None) -> Dict[str, Any]:
    """Load the combined scanner JSON file (e.g., Trivy + Grype combined).
    
    Args:
        path: Path to the JSON file
        limit: If set, only load first N vulnerabilities (for testing)
    """
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    
    # Limit vulnerabilities for testing if requested
    if limit is not None and "vulnerabilities" in data:
        data["vulnerabilities"] = data["vulnerabilities"][:limit]
        data["total_vulnerabilities"] = len(data["vulnerabilities"])
    
    return data

def _norm(v: Any) -> str:
    return "" if v is None else str(v)

def _extract_vuln_id(rec: Dict[str, Any]) -> Tuple[str, str]:
    # Prefer CVE ID if present; else allow GHSA or other IDs
    cve = _norm(rec.get("cve_id"))
    if cve:
        return cve, "cve"
    # Try common alternate fields
    for k in ("vuln_id", "id", "ghsa_id"):
        if _norm(rec.get(k)):
            return _norm(rec.get(k)), k
    # As a last resort, derive from title
    title = _norm(rec.get("title"))
    if title.startswith("CVE-"):
        token = title.split()[0].strip(",.:;")
        return token, "derived"
    return "", "unknown"

def build_vuln_frame(data: Dict[str, Any]) -> pd.DataFrame:
    """Normalize the vulnerability list into a DataFrame."""
    image = _norm(data.get("image"))
    vulns: List[Dict[str, Any]] = data.get("vulnerabilities", []) or []
    rows: List[Dict[str, Any]] = []
    for rec in vulns:
        if not isinstance(rec, dict) or not rec:
            continue
        vuln_id, id_source = _extract_vuln_id(rec)
        if not vuln_id:
            continue
        rows.append({
            "vuln_id": vuln_id,
            "id_source": id_source,
            "package_name": _norm(rec.get("package_name")),
            "installed_version": _norm(rec.get("installed_version")),
            "fixed_version": _norm(rec.get("fixed_version")),
            "severity": _norm(rec.get("severity")),
            "scanner": _norm(rec.get("scanner")),
            "title": _norm(rec.get("title")),
            "description": _norm(rec.get("description")),
            "image": image or _norm(rec.get("image")),
        })
    if not rows:
        return pd.DataFrame(columns=[
            "vuln_id", "id_source", "package_name", "installed_version", "fixed_version",
            "severity", "scanner", "title", "description", "image"
        ])
    df = pd.DataFrame(rows)
    df = df.drop_duplicates()
    return df

print('✓ Scanner loader module loaded')
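
The loader expects a top-level object with an `image` string and a `vulnerabilities` list. The sketch below shows a hand-written payload in that shape and re-implements the ID preference order (CVE field, then alternate ID fields, then a `CVE-` prefix in the title) with the standard library only. The sample records and the `pick_id` name are invented for illustration, and real combined files may carry more fields.

```python
from typing import Any, Dict, Tuple

sample = {
    "image": "example/image:1.0",          # hypothetical image name
    "vulnerabilities": [
        {"cve_id": "CVE-2020-0001", "package_name": "libfoo", "scanner": "trivy"},
        {"ghsa_id": "GHSA-xxxx-yyyy-zzzz", "package_name": "libbar", "scanner": "grype"},
        {"title": "CVE-2021-0002: overflow in libbaz", "package_name": "libbaz"},
        {"package_name": "no-identifier"},  # no usable ID: build_vuln_frame would skip it
    ],
}

def pick_id(rec: Dict[str, Any]) -> Tuple[str, str]:
    """Same preference order as _extract_vuln_id, returning (id, source)."""
    if rec.get("cve_id"):
        return str(rec["cve_id"]), "cve"
    for key in ("vuln_id", "id", "ghsa_id"):
        if rec.get(key):
            return str(rec[key]), key
    title = str(rec.get("title") or "")
    if title.startswith("CVE-"):
        # Take the first token and trim trailing punctuation such as ":".
        return title.split()[0].strip(",.:;"), "derived"
    return "", "unknown"

for rec in sample["vulnerabilities"]:
    print(rec["package_name"], pick_id(rec))
```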

## Step 6: VEX Reasoner Module

Core AI reasoning engine using DSPy or direct Gemini.

In [None]:
import re
from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Optional, Tuple

@dataclass
class VEXResult:
    vuln_id: str
    package_name: str
    affected: bool
    label: str
    reason: str
    risk: str
    remediation: str
    raw_model_text: str

_JSON_OBJ_RE = re.compile(r"\{[\s\S]*\}")

def _safe_json_parse(text: str) -> Optional[Dict[str, Any]]:
    # Greedy match: captures everything from the first '{' to the last '}' in the text
    m = _JSON_OBJ_RE.search(text)
    if not m:
        return None
    try:
        return json.loads(m.group(0))
    except Exception:
        return None

def _normalize_label(label: str) -> str:
    low = (label or "").strip().lower()
    for choice in LABEL_CHOICES:
        if low == choice:
            return choice
    # Map common near-misses
    aliases = {
        "not_present": "code_not_present",
        "not_reachable": "code_not_reachable",
        "unexploitable": "mitigated",
        "remediated": "fixed",
        "true_positive": "vulnerable",
        "benign": "false_positive",
    }
    return aliases.get(low, low or "false_positive")

class _DSPyPredictor:
    def __init__(self) -> None:
        import dspy
        self._predict = dspy.Predict(VEXSignature)

    def run_once(self, record: Dict[str, Any]) -> Tuple[str, str]:
        out = self._predict(
            cve_id=record.get("vuln_id", ""),
            package_name=record.get("package_name", ""),
            installed_version=record.get("installed_version", ""),
            fixed_version=record.get("fixed_version", ""),
            severity=record.get("severity", ""),
            title=record.get("title", ""),
            description=record.get("description", ""),
            image=record.get("image", ""),
            scanner=record.get("scanner", ""),
        )
        raw = getattr(out, "output_json", None) or str(out)
        return str(raw), "dspy"

class _GeminiFallback:
    def __init__(self) -> None:
        import google.generativeai as genai
        api_key, model = get_gemini_settings()
        if not api_key:
            raise RuntimeError("GEMINI_API_KEY not configured")
        genai.configure(api_key=api_key)
        self._model = genai.GenerativeModel(model)

    def run_once(self, record: Dict[str, Any]) -> Tuple[str, str]:
        user_prompt = f"""
{SYSTEM_INSTRUCTIONS}

Vulnerability:
- vuln_id: {record.get('vuln_id','')}
- package_name: {record.get('package_name','')}
- installed_version: {record.get('installed_version','')}
- fixed_version: {record.get('fixed_version','')}
- severity: {record.get('severity','')}
- title: {record.get('title','')}
- description: {record.get('description','')}
- image: {record.get('image','')}
- scanner: {record.get('scanner','')}

Return ONLY a strict JSON object.
""".strip()
        resp = self._model.generate_content(user_prompt)
        text = getattr(resp, "text", None) or str(resp)
        return str(text), "gemini"

def _choose_engine() -> Any:
    mode = configure_dspy()
    if mode == "dspy" and VEXSignature is not None:
        try:
            return _DSPyPredictor()
        except Exception:
            pass
    return _GeminiFallback()

def assess_record(engine: Any, rec: Dict[str, Any]) -> VEXResult:
    raw_text, used = engine.run_once(rec)
    parsed = _safe_json_parse(raw_text) or {}
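    raw_affected = parsed.get("affected", False)
    # bool("false") is True, so handle string replies explicitly
    if isinstance(raw_affected, bool):
        affected = raw_affected
    else:
        affected = str(raw_affected).strip().lower() == "true"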
    label = _normalize_label(str(parsed.get("label", "")))
    reason = str(parsed.get("reason", ""))
    risk = str(parsed.get("risk", "")) or ("high" if rec.get("severity", "").upper() in {"CRITICAL", "HIGH"} else "medium")
    remediation = str(parsed.get("remediation", ""))
    return VEXResult(
        vuln_id=str(rec.get("vuln_id", "")),
        package_name=str(rec.get("package_name", "")),
        affected=affected,
        label=label,
        reason=(f"[{used}] {reason}" if reason else f"[{used}] no detailed reason"),
        risk=risk,
        remediation=remediation,
        raw_model_text=raw_text,
    )

def assess_batch(records: Iterable[Dict[str, Any]]) -> List[VEXResult]:
    engine = _choose_engine()
    results: List[VEXResult] = []
    for rec in records:
        try:
            results.append(assess_record(engine, rec))
        except Exception as e:
            results.append(VEXResult(
                vuln_id=str(rec.get("vuln_id", "")),
                package_name=str(rec.get("package_name", "")),
                affected=False,
                label="error",
                reason=f"engine_error: {e}",
                risk="unknown",
                remediation="",
                raw_model_text="",
            ))
    return results

def to_json(results: List[VEXResult]) -> List[Dict[str, Any]]:
    return [{
        "vuln_id": r.vuln_id,
        "package_name": r.package_name,
        "justification": {
            "affected": r.affected,
            "label": r.label,
            "reason": r.reason,
            "risk": r.risk,
            "remediation": r.remediation,
        },
        "raw_model_text": r.raw_model_text,
    } for r in results]

def to_markdown(results: List[VEXResult]) -> str:
    lines: List[str] = ["# Vulnerability Assessment Report", ""]
    for r in results:
        lines.extend([
            f"## {r.vuln_id} — {r.package_name}",
            f"- Affected: {r.affected}",
            f"- Label: {r.label}",
            f"- Risk: {r.risk}",
            "- Reason:",
            f"  {r.reason}",
            "- Remediation:",
            f"  {r.remediation or 'N/A'}",
            "",
        ])
    return "\n".join(lines)

print('✓ VEX reasoner module loaded')
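
`_JSON_OBJ_RE` is greedy, so it captures everything from the first `{` to the last `}` in the reply. A reply that contains braces in prose after the JSON object will therefore fail to parse. One alternative, sketched here and not what the notebook does, is to try `json.JSONDecoder.raw_decode` at each `{` and keep the first object that decodes. `extract_first_json` is a hypothetical name.

```python
import json
from typing import Any, Dict, Optional

_DECODER = json.JSONDecoder()

def extract_first_json(text: str) -> Optional[Dict[str, Any]]:
    """Return the first decodable JSON object embedded in text, or None."""
    pos = text.find("{")
    while pos != -1:
        try:
            # raw_decode parses one JSON value starting at pos and ignores what follows it.
            obj, _end = _DECODER.raw_decode(text, pos)
        except json.JSONDecodeError:
            obj = None
        if isinstance(obj, dict):
            return obj
        pos = text.find("{", pos + 1)
    return None

# Invented model reply: valid JSON followed by prose that contains braces.
reply = 'Sure. {"affected": true, "label": "vulnerable"} Let me know if {more} is needed.'
print(extract_first_json(reply))
print(extract_first_json("no json here"))
```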

## Step 7: Main Assessment Function

Orchestrates the full assessment pipeline.

In [None]:
from datetime import datetime

def run_assessment(input_path: str, output_dir: str = '/content/outputs') -> Dict[str, Any]:
    """
    Run vulnerability assessment on scanner JSON file.
    
    Args:
        input_path: Path to combined scanner JSON
        output_dir: Directory to save outputs
    
    Returns:
        Dictionary with json_path, md_path, and summary_df
    """
    # Load and parse scanner data
    print(f'Loading scanner results from: {input_path}')
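    # NOTE: limit=5 caps the run at the first five records for quick testing;
    # pass limit=None to assess every vulnerability in the file.
    data = load_scanner_results(input_path, limit=5)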
    df = build_vuln_frame(data)
    
    if df.empty:
        raise RuntimeError("No vulnerabilities found in the provided file.")
    
    print(f'Found {len(df)} vulnerability records')
    
    # De-duplicate to reduce API calls
    df_slice = df.sort_values(["vuln_id", "package_name"]).drop_duplicates([
        "vuln_id", "package_name", "installed_version", "scanner"
    ])
    print(f'Processing {len(df_slice)} unique vulnerabilities...')
    
    # Run assessment
    records = df_slice.to_dict(orient="records")
    results = assess_batch(records)
    
    # Convert to output formats
    json_out = to_json(results)
    md_out = to_markdown(results)
    
    # Save outputs
    os.makedirs(output_dir, exist_ok=True)
    ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    json_path = os.path.join(output_dir, f"assessment_{ts}.json")
    md_path = os.path.join(output_dir, f"assessment_{ts}.md")
    
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump({
            "input": {
                "image": data.get("image"),
                "scanners_used": data.get("scanners_used"),
                "total_vulnerabilities": data.get("total_vulnerabilities"),
                "source_file": os.path.abspath(input_path),
            },
            "output": json_out,
        }, f, indent=2)
    
    with open(md_path, "w", encoding="utf-8") as f:
        f.write(md_out)
    
    # Create summary DataFrame
    df_out = pd.DataFrame(json_out)
    df_out[["affected", "label", "reason", "risk", "remediation"]] = pd.json_normalize(df_out["justification"])
    df_out.drop(columns=["justification"], inplace=True)
    
    print(f'\n✓ Assessment complete!')
    print(f'  JSON: {json_path}')
    print(f'  Markdown: {md_path}')
    
    return {
        "json_path": json_path,
        "md_path": md_path,
        "summary_df": df_out,
    }

print('✓ Assessment function loaded')
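
The de-duplication step keeps one row per (`vuln_id`, `package_name`, `installed_version`, `scanner`) combination, so the same CVE reported by both Trivy and Grype is still assessed twice. The sketch below reproduces that keying with plain dictionaries to make the effect visible. It skips the sort that the notebook applies first, and the records are invented.

```python
from typing import Dict, List

KEY_FIELDS = ("vuln_id", "package_name", "installed_version", "scanner")

def dedupe(records: List[Dict[str, str]]) -> List[Dict[str, str]]:
    """Keep the first record for each key, preserving input order."""
    seen = set()
    unique = []
    for rec in records:
        key = tuple(rec.get(f, "") for f in KEY_FIELDS)
        if key not in seen:
            seen.add(key)
            unique.append(rec)
    return unique

records = [
    {"vuln_id": "CVE-2020-0001", "package_name": "libfoo", "installed_version": "1.0", "scanner": "trivy"},
    {"vuln_id": "CVE-2020-0001", "package_name": "libfoo", "installed_version": "1.0", "scanner": "trivy"},
    {"vuln_id": "CVE-2020-0001", "package_name": "libfoo", "installed_version": "1.0", "scanner": "grype"},
]
# The exact repeat is dropped; the Grype finding stays because its scanner differs.
print([r["scanner"] for r in dedupe(records)])
```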

## Step 8: Run Vulnerability Assessment

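**Note:** `run_assessment` as defined in Step 7 loads only the first 5 vulnerabilities (`limit=5`). With that limit removed, a full run over ~650 vulnerabilities may take 10-30 minutes.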

In [None]:
# Define input file path
INPUT_FILE = '/content/LLM-Assisted-Container-Security-Analysis/Scanner/combined_results/hyperledger_fabric-peer_1.1.0_combined.json'

# Verify file exists
if not os.path.exists(INPUT_FILE):
    raise FileNotFoundError(f'Scanner file not found: {INPUT_FILE}')

print('Starting vulnerability assessment...')
print(f'Input: {INPUT_FILE}\n')

# Run assessment
result = run_assessment(INPUT_FILE, output_dir='/content/outputs')
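
The 10-30 minute figure for the full ~650-record file is consistent with one sequential API call per record. As a rough check, assume a per-call latency between about 1 and 2.8 seconds. That range is chosen to match the quoted figure and is not a measurement.

```python
def estimate_minutes(n_records: int, seconds_per_call: float) -> float:
    """Sequential run time in minutes: one call per record, no parallelism."""
    return n_records * seconds_per_call / 60

# Assumed latencies, picked to bracket the quoted 10-30 minute range.
for latency in (1.0, 2.8):
    print(f"{latency:.1f} s/call -> {estimate_minutes(650, latency):.1f} min")
```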

## Step 9: Display Results Summary

In [None]:
summary_df = result['summary_df']

print('\n=== Vulnerability Assessment Summary ===')
print(f'Total assessed: {len(summary_df)}')
print(f'Affected: {summary_df["affected"].sum()}')
print(f'Not affected: {(~summary_df["affected"]).sum()}')

print('\n=== Label Distribution ===')
print(summary_df['label'].value_counts())

print('\n=== Risk Distribution ===')
print(summary_df['risk'].value_counts())

print('\n=== First 10 Results ===')
display(summary_df[['vuln_id', 'package_name', 'affected', 'label', 'risk']].head(10))

## Step 10: View Detailed Results

In [None]:
# Display full results with reasons and remediation
print('=== Detailed Results ===')
display(summary_df[['vuln_id', 'package_name', 'affected', 'label', 'risk', 'reason', 'remediation']])

## Step 11: Filter Specific Results (Optional)

In [None]:
# Show only affected vulnerabilities
print('=== Affected Vulnerabilities ===')
affected = summary_df[summary_df['affected'] == True]
display(affected[['vuln_id', 'package_name', 'label', 'risk', 'reason']])

# Show high-risk items
print('\n=== High Risk Items ===')
high_risk = summary_df[summary_df['risk'] == 'high']
display(high_risk[['vuln_id', 'package_name', 'affected', 'label', 'reason']])

## Step 12: View JSON Output Structure

In [None]:
# Load and display JSON structure
with open(result['json_path'], 'r', encoding='utf-8') as f:
    json_output = json.load(f)

print('=== JSON Output Structure ===')
print(f"Keys: {list(json_output.keys())}")
print(f"\nInput metadata:")
print(f"  Image: {json_output['input']['image']}")
print(f"  Scanners: {json_output['input']['scanners_used']}")
print(f"  Total vulnerabilities: {json_output['input']['total_vulnerabilities']}")
print(f"\nAssessed: {len(json_output['output'])} vulnerabilities")

# Show example result
print('\n=== Example Result ===')
print(json.dumps(json_output['output'][0], indent=2))

## Step 13: Preview Markdown Report

In [None]:
# Display first 50 lines of markdown report
with open(result['md_path'], 'r', encoding='utf-8') as f:
    md_content = f.read()

lines = md_content.split('\n')
preview_lines = min(50, len(lines))
print(f'=== Markdown Report Preview (first {preview_lines} lines) ===')
print('\n'.join(lines[:preview_lines]))
if len(lines) > 50:
    print(f'\n... ({len(lines) - 50} more lines)')

## Step 14: Generate Statistics

In [None]:
# Generate comprehensive statistics
print('=== Comprehensive Statistics ===')
print(f'\nTotal Vulnerabilities: {len(summary_df)}')
print(f'Unique vulnerability IDs: {summary_df["vuln_id"].nunique()}')
print(f'Unique Packages: {summary_df["package_name"].nunique()}')

print('\n--- By Affected Status ---')
print(summary_df['affected'].value_counts())

print('\n--- By Label ---')
for label, count in summary_df['label'].value_counts().items():
    pct = (count / len(summary_df)) * 100
    print(f'{label:20s}: {count:4d} ({pct:5.1f}%)')

print('\n--- By Risk Level ---')
for risk, count in summary_df['risk'].value_counts().items():
    pct = (count / len(summary_df)) * 100
    print(f'{risk:20s}: {count:4d} ({pct:5.1f}%)')

print('\n--- Top 10 Most Vulnerable Packages ---')
vulnerable_packages = summary_df[summary_df['affected'] == True]['package_name'].value_counts().head(10)
for pkg, count in vulnerable_packages.items():
    print(f'{pkg:30s}: {count:2d} vulnerabilities')

## Step 15: Download Results

In [None]:
from google.colab import files

# Download JSON report
print('Downloading JSON report...')
files.download(result['json_path'])

# Download Markdown report
print('Downloading Markdown report...')
files.download(result['md_path'])

print('✓ Downloads initiated')

## Notes & Documentation

### VEX Labels Used
- **vulnerable**: Exploitable as-deployed in this container
- **code_not_present**: Package/code not actually present in image
- **code_not_reachable**: Present but not reachable/exposed at runtime
- **mitigated**: Present but mitigations block exploitation
- **fixed**: Fixed version is present
- **false_positive**: Scanner likely wrong

### Architecture
1. **Configuration**: Manages API keys and DSPy setup
2. **Scanner Loader**: Parses combined Trivy+Grype JSON
3. **Prompts**: Defines system instructions and VEX labels
4. **VEX Reasoner**: Core AI engine (DSPy or direct Gemini)
5. **Assessment**: Orchestrates full pipeline

### Performance
- Processing time: ~10-30 minutes for 650 vulnerabilities
- Rate limits may apply based on your API key tier
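- Records are de-duplicated on `vuln_id`, `package_name`, `installed_version`, and `scanner` before assessment, so exact repeats do not trigger extra API calls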

### Troubleshooting
- **API Key Issues**: Verify key is valid and has quota
- **Import Errors**: Re-run Step 1 to reinstall dependencies
- **Rate Limits**: Add delays between calls if needed
- **DSPy Fallback**: Automatically uses direct Gemini if DSPy fails
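
For the rate-limit case, one option is to wrap each engine call in a retry loop with exponential backoff. This is a sketch only. `call_with_backoff` is a hypothetical helper, it retries on any exception because the exact rate-limit exception type depends on the client library, and the delays are placeholders to tune against your quota.

```python
import time
from typing import Callable, TypeVar

T = TypeVar("T")

def call_with_backoff(fn: Callable[[], T], retries: int = 3,
                      base_delay: float = 1.0,
                      sleep: Callable[[float], None] = time.sleep) -> T:
    """Call fn(); on failure wait base_delay * 2**attempt seconds and try again."""
    for attempt in range(retries + 1):
        try:
            return fn()
        except Exception:
            if attempt == retries:
                raise  # out of retries: surface the last error
            sleep(base_delay * (2 ** attempt))
    raise RuntimeError("unreachable")

# Demo with a fake call that fails twice, then succeeds.
# Delays are recorded in a list instead of actually sleeping.
attempts = {"n": 0}

def flaky() -> str:
    attempts["n"] += 1
    if attempts["n"] < 3:
        raise RuntimeError("429 Too Many Requests")
    return "ok"

waits = []
print(call_with_backoff(flaky, sleep=waits.append), waits)
```

In the notebook this would wrap the call inside `assess_record`, for example `call_with_backoff(lambda: engine.run_once(rec))`.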

### Repository
GitHub: [satyam-thakur/LLM-Assisted-Container-Security-Analysis](https://github.com/satyam-thakur/LLM-Assisted-Container-Security-Analysis)

---
*All code is self-contained in this notebook - no external Python files required.*