In [8]:
pip install python-dotenv -q


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.2[0m[39;49m -> [0m[32;49m25.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip3 install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


In [None]:
#!/usr/bin/env python3
"""
Optimized VirusTotal collection + MITRE ATT&CK mapping pipeline.

Usage:
  - Put your list of candidate hashes in a CSV (columns: hash, malware_family, source).
  - Set environment variable VIRUSTOTAL_API_KEY (recommended) or use a .env file.
  - Run: python vt_attck_pipeline.py

Outputs:
  - ./output/analysis_results.csv
  - ./output/analysis_results.json
  - ./output/raw_responses/<hash>_file.json
  - ./output/raw_responses/<hash>_behavior.json
  - ./output/attck_navigator_layer.json
"""

import os
import time
import json
import csv
import math
import logging
from typing import List, Dict, Any, Optional
from datetime import datetime
from collections import Counter, defaultdict

import requests
from requests.adapters import HTTPAdapter, Retry
import pandas as pd
from tqdm import tqdm

# Optional: python-dotenv to load .env (if you prefer)
try:
    from dotenv import load_dotenv
    load_dotenv()
except Exception:
    pass

# -----------------------------
# Configuration (edit if needed)
# -----------------------------
API_KEY = os.getenv("VIRUSTOTAL_API_KEY")  # REQUIRED: set this in your environment
if not API_KEY:
    raise SystemExit("ERROR: Set environment variable VIRUSTOTAL_API_KEY before running.")

HEADERS = {"x-apikey": API_KEY}
VT_BASE = "https://www.virustotal.com/api/v3"

OUTPUT_DIR = "./output"
RAW_DIR = os.path.join(OUTPUT_DIR, "raw_responses")
os.makedirs(RAW_DIR, exist_ok=True)

# Rate limiting: Free tier ~4 req/min => ~15s per request; keep conservative default.
DELAY_BETWEEN_REQUESTS = 16.0

# Collection goals (per your grad student requirements)
TARGET_MIN_SAMPLES = 75     # aim for 75 or more
TARGET_MAX_SAMPLES = 100    # cap at 100
MIN_FAMILIES = 6            # >5 families -> set 6 minimum

# File paths
DEFAULT_INPUT_CSV = "./output/hash_signature_output.csv"  # modify if necessary

# Logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")

# -----------------------------
# HTTP Session with retries
# -----------------------------
session = requests.Session()
retries = Retry(total=5, backoff_factor=1,
                status_forcelist=[429, 500, 502, 503, 504],
                allowed_methods=["GET", "POST"])
session.mount("https://", HTTPAdapter(max_retries=retries))


# -----------------------------
# Helper functions
# -----------------------------
def read_hash_csv(csv_path: str, hash_col: str = "hash") -> pd.DataFrame:
    """Read CSV containing hashes and family metadata."""
    df = pd.read_csv(csv_path, dtype=str).fillna("")
    if hash_col not in df.columns:
        raise ValueError(f"CSV must contain a '{hash_col}' column. Found: {list(df.columns)}")
    df['hash'] = df['hash'].str.strip().str.lower()
    # normalize family field
    if 'malware_family' not in df.columns:
        df['malware_family'] = "Unknown"
    if 'source' not in df.columns:
        df['source'] = "Unknown"
    df = df[df['hash'] != ""].drop_duplicates(subset=['hash'])
    return df


def select_sample_hashes(df: pd.DataFrame,
                         min_samples: int = TARGET_MIN_SAMPLES,
                         max_samples: int = TARGET_MAX_SAMPLES,
                         min_families: int = MIN_FAMILIES) -> pd.DataFrame:
    """
    Select a balanced set of hashes from df ensuring at least min_families.
    - If dataset doesn't have enough families, function will select as many as available.
    - Attempts to distribute samples across families.
    """
    families = df['malware_family'].value_counts()
    available_families = list(families.index)
    logging.info(f"Available families in input: {len(available_families)}")

    # If there are fewer families than required, warn but continue
    families_to_use = available_families[:]
    if len(families_to_use) < min_families:
        logging.warning(f"Input has only {len(families_to_use)} families; requested {min_families}. "
                        "Proceeding with available families.")

    # Decide final target
    target = min(max(len(df), min_samples), max_samples)
    target = min(target, len(df))  # cannot exceed available rows
    logging.info(f"Target sample count: {target}")

    # Aim to sample evenly across families
    per_family = math.ceil(target / max(1, len(families_to_use)))
    selected = []
    for fam in families_to_use:
        fam_rows = df[df['malware_family'] == fam]
        n = min(per_family, len(fam_rows))
        selected.append(fam_rows.sample(n=n, random_state=42))
        if sum(len(s) for s in selected) >= target:
            break
    sampled_df = pd.concat(selected, ignore_index=True).head(target).reset_index(drop=True)
    logging.info(f"Selected {len(sampled_df)} samples from {sampled_df['malware_family'].nunique()} families")
    return sampled_df


def vt_get(endpoint: str, params: Optional[Dict] = None, max_retries: int = 4) -> Dict[str, Any]:
    """
    Safe GET to VirusTotal with exponential backoff on 429/5xx.
    Returns dict with keys: success(bool), status_code, json (if any), error (if any)
    """
    url = f"{VT_BASE}{endpoint}"
    attempt = 0
    backoff = 1.0
    while attempt <= max_retries:
        try:
            resp = session.get(url, headers=HEADERS, params=params, timeout=30)
            if resp.status_code == 200:
                return {'success': True, 'status_code': 200, 'json': resp.json()}
            if resp.status_code == 404:
                return {'success': False, 'status_code': 404, 'error': 'Not found'}
            if resp.status_code == 429:
                # rate-limited; exponential backoff
                logging.warning(f"429 rate limit from VT; sleeping {backoff:.1f}s (attempt {attempt})")
                time.sleep(backoff)
                backoff *= 2
                attempt += 1
                continue
            if 500 <= resp.status_code < 600:
                logging.warning(f"Server error {resp.status_code}; attempt {attempt}. Backoff {backoff}s")
                time.sleep(backoff)
                backoff *= 2
                attempt += 1
                continue
            # Other errors: return immediately with message
            try:
                j = resp.json()
            except Exception:
                j = None
            return {'success': False, 'status_code': resp.status_code, 'json': j, 'error': f"HTTP {resp.status_code}"}
        except requests.RequestException as e:
            logging.warning(f"Request error: {e}; backing off {backoff}s")
            time.sleep(backoff)
            backoff *= 2
            attempt += 1
    return {'success': False, 'status_code': None, 'error': 'Max retries exceeded'}


def get_file_report(hash_val: str) -> Dict[str, Any]:
    return vt_get(f"/files/{hash_val}")


def get_behavior_report(hash_val: str) -> Dict[str, Any]:
    # behaviour_summary endpoint returns structured behavior summary when present
    return vt_get(f"/files/{hash_val}/behaviour_summary")


def extract_behavioral_indicators(behav_resp: Dict[str, Any]) -> Dict[str, Any]:
    """
    Parse behaviour_summary response into flattened indicators safely.
    Assumes vt_get return structure.
    """
    indicators = {
        'processes_created': [],
        'files_written': [],
        'files_deleted': [],
        'registry_keys_set': [],
        'registry_keys_deleted': [],
        'dns_lookups': [],
        'ip_traffic': [],
        'http_conversations': [],
        'command_executions': [],
        'mutexes_created': [],
        'services_created': [],
        'mitre_techniques': []
    }
    if not behav_resp.get('success'):
        return indicators

    try:
        data = behav_resp.get('json', {}).get('data', {}).get('data', {}) or {}
        attrs = data.get('attributes', {}) or {}

        # copy known lists safely
        indicators['processes_created'] = attrs.get('processes_created', []) or []
        indicators['files_written'] = attrs.get('files_written', []) or []
        indicators['files_deleted'] = attrs.get('files_deleted', []) or []
        indicators['registry_keys_set'] = attrs.get('registry_keys_set', []) or []
        indicators['registry_keys_deleted'] = attrs.get('registry_keys_deleted', []) or []
        indicators['command_executions'] = attrs.get('command_executions', []) or []
        indicators['mutexes_created'] = attrs.get('mutexes_created', []) or []
        indicators['services_created'] = attrs.get('services_created', []) or []

        if 'dns_lookups' in attrs:
            indicators['dns_lookups'] = [d.get('hostname', '') for d in attrs.get('dns_lookups', []) if isinstance(d, dict)]
        if 'ip_traffic' in attrs:
            ip_data = attrs.get('ip_traffic', [])
            indicators['ip_traffic'] = [
                f"{ip.get('destination_ip','')}:{ip.get('destination_port','')}" for ip in ip_data if isinstance(ip, dict)
            ]
        if 'http_conversations' in attrs:
            indicators['http_conversations'] = [h.get('url','') for h in attrs.get('http_conversations', []) if isinstance(h, dict)]
        if 'mitre_attack_techniques' in attrs:
            mt = attrs.get('mitre_attack_techniques', [])
            # ensure each technique is a dict and has id and name fields
            indicators['mitre_techniques'] = [t for t in mt if isinstance(t, dict)]
    except Exception as e:
        logging.error(f"Error parsing behavior response: {e}")
    return indicators


def build_attck_layer(tech_counter: Counter) -> Dict[str, Any]:
    """
    Create a simple ATT&CK Navigator layer JSON using counted technique IDs.
    Each technique will have a "score" equal to its count (normalized optional).
    """
    techniques = []
    max_count = max(tech_counter.values()) if tech_counter else 1
    for tid, count in tech_counter.items():
        techniques.append({
            "techniqueID": tid,
            "tactic": "",  # tactic unknown from behaviour_summary, left blank
            "score": int(count)
        })
    layer = {
        "name": "VT -> MITRE ATT&CK Layer",
        "description": "Techniques extracted from VirusTotal behaviour summaries",
        "domain": "mitre-enterprise",
        "version": "4.3",
        "techniques": techniques
    }
    return layer


# -----------------------------
# Main collection routine
# -----------------------------
def collect_samples(sample_df: pd.DataFrame,
                    delay_between_requests: float = DELAY_BETWEEN_REQUESTS,
                    save_every: int = 10) -> List[Dict[str, Any]]:
    results = []
    tech_counter = Counter()

    # progress bar for user feedback
    for idx, row in enumerate(tqdm(sample_df.to_dict(orient="records"), desc="Collecting samples"), start=1):
        h = row['hash']
        family = row.get('malware_family', 'Unknown')
        source = row.get('source', 'Unknown')
        result = {
            'hash': h,
            'family': family,
            'source': source,
            'status': 'failed',
            'detection_ratio': None,
            'first_seen': None,
            'last_seen': None,
            'processes_count': 0,
            'files_written_count': 0,
            'files_deleted_count': 0,
            'registry_keys_set_count': 0,
            'dns_lookups_count': 0,
            'ip_connections_count': 0,
            'http_requests_count': 0,
            'mutexes_count': 0,
            'mitre_techniques_count': 0,
            'mitre_techniques': '',
            'collected_date': datetime.utcnow().isoformat() + "Z"
        }

        # 1) File report
        fr = get_file_report(h)
        # Save raw file report even if failed
        with open(os.path.join(RAW_DIR, f"{h}_file.json"), "w") as fh:
            json.dump(fr, fh, indent=2, default=str)

        if not fr.get('success'):
            result['error'] = fr.get('error', 'file_report_failed')
            results.append(result)
            # Respect rate limit even after failure
            time.sleep(delay_between_requests)
            continue

        # Parse detection statistics if present
        try:
            attrs = fr['json'].get('data', {}).get('attributes', {}) if fr.get('json') else {}
            stats = attrs.get('last_analysis_stats', {})
            total = sum(stats.values()) if isinstance(stats, dict) else 0
            malicious = stats.get('malicious', 0) if isinstance(stats, dict) else 0
            result['detection_ratio'] = f"{malicious}/{total}" if total else None
            result['first_seen'] = attrs.get('first_submission_date') or None
            result['last_seen'] = attrs.get('last_analysis_date') or None
        except Exception as e:
            logging.debug(f"Error extracting stats for {h}: {e}")

        # Wait to respect rate limits
        time.sleep(delay_between_requests)

        # 2) Behaviour report
        br = get_behavior_report(h)
        with open(os.path.join(RAW_DIR, f"{h}_behavior.json"), "w") as fh:
            json.dump(br, fh, indent=2, default=str)

        indicators = extract_behavioral_indicators(br)
        # fill counts
        result.update({
            'processes_count': len(indicators['processes_created']),
            'files_written_count': len(indicators['files_written']),
            'files_deleted_count': len(indicators['files_deleted']),
            'registry_keys_set_count': len(indicators['registry_keys_set']),
            'dns_lookups_count': len(indicators['dns_lookups']),
            'ip_connections_count': len(indicators['ip_traffic']),
            'http_requests_count': len(indicators['http_conversations']),
            'mutexes_count': len(indicators['mutexes_created']),
            'mitre_techniques_count': len(indicators['mitre_techniques']),
            'mitre_techniques': ", ".join([t.get('id','') for t in indicators['mitre_techniques'] if isinstance(t, dict)])
        })
        # update technique counters
        for t in indicators['mitre_techniques']:
            if isinstance(t, dict) and t.get('id'):
                tech_counter[t['id']] += 1

        result['status'] = 'success'
        results.append(result)

        # Save partial progress regularly
        if idx % save_every == 0 or idx == len(sample_df):
            df_partial = pd.DataFrame(results)
            df_partial.to_csv(os.path.join(OUTPUT_DIR, "analysis_results_partial.csv"), index=False)
            with open(os.path.join(OUTPUT_DIR, "analysis_results_partial.json"), "w") as fh:
                json.dump(results, fh, indent=2)
            logging.info(f"Saved partial results at {idx}/{len(sample_df)}")

        # final wait before next loop (already waited once, but keep conservative)
        if idx < len(sample_df):
            time.sleep(delay_between_requests)

    # Save final outputs
    df_final = pd.DataFrame(results)
    df_final.to_csv(os.path.join(OUTPUT_DIR, "analysis_results.csv"), index=False)
    df_final.to_json(os.path.join(OUTPUT_DIR, "analysis_results.json"), orient="records", indent=2)

    # Build ATT&CK Navigator layer
    layer_json = build_attck_layer(tech_counter)
    with open(os.path.join(OUTPUT_DIR, "attck_navigator_layer.json"), "w") as fh:
        json.dump(layer_json, fh, indent=2)

    logging.info("Collection complete. Outputs saved to ./output")
    return results


# -----------------------------
# Entry point
# -----------------------------
def main(input_csv: str = DEFAULT_INPUT_CSV,
         min_samples: int = TARGET_MIN_SAMPLES,
         max_samples: int = TARGET_MAX_SAMPLES,
         min_families: int = MIN_FAMILIES):
    logging.info("Reading input CSV...")
    df = read_hash_csv(input_csv)
    if df.empty:
        raise SystemExit("No hashes found in input CSV. Exiting.")

    # Build selection
    samples = select_sample_hashes(df, min_samples=min_samples, max_samples=max_samples, min_families=min_families)
    logging.info(f"Beginning collection of {len(samples)} samples (delay {DELAY_BETWEEN_REQUESTS}s between major API calls)")

    # Confirm with console user (non-interactive environments can set CONFIRM=yes env var)
    if os.getenv("CONFIRM", "no").lower() != "yes":
        print(f"About to query VirusTotal for {len(samples)} hashes. This will respect rate-limits ({DELAY_BETWEEN_REQUESTS}s per request).")
        proceed = input("Proceed? (yes/no): ").strip().lower()
        if proceed != "yes":
            logging.info("User aborted.")
            return

    results = collect_samples(samples)
    # Print short summary
    df_res = pd.DataFrame(results)
    success_count = (df_res['status'] == 'success').sum()
    logging.info(f"Finished. Success: {success_count}/{len(df_res)}")
    # Suggest next steps
    logging.info("You can open output/attck_navigator_layer.json in ATT&CK Navigator to visualize technique coverage.")


if __name__ == "__main__":
    main()


2025-11-03 22:32:37,793 - INFO - Reading input CSV...


FileNotFoundError: [Errno 2] No such file or directory: './hash_data/hash_signature_output.csv'