# VirusTotal → MITRE ATT&CK Pipeline (Grad-level)
This notebook collects VirusTotal (VT) file + behaviour summaries for a selected set of malware hashes,
maps behaviours to MITRE ATT&CK techniques, and exports results for analysis and ATT&CK Navigator.

**What you'll get:**
- `output/analysis_results.csv` and `.json` (per-sample indicators and technique lists)
- `output/raw_responses/` (raw VT responses for audit)
- `output/attck_navigator_layer.json` (import into ATT&CK Navigator)

**Grad requirements enforced:**
- ≥ 75 samples
- ≥ 6 malware families


🧩 Cell 2 - Imports

In [1]:
# Core / stdlib
import os
import time
import json
import math
import logging
from typing import List, Dict, Any, Optional
from datetime import datetime, timezone
from collections import Counter

# Third-party libs
import requests
from requests.adapters import HTTPAdapter, Retry
import pandas as pd
from tqdm import tqdm

# Optional: load .env (so you can store VIRUSTOTAL_API_KEY there)
try:
    from dotenv import load_dotenv
    load_dotenv()
except Exception:
    pass

print("✅ Imports ready")


✅ Imports ready


🧩 Cell 3 - Configuration (edit here)

In [3]:
# ---- USER-EDITABLE CONFIG ----

# 1) API key: set via env var or .env file (recommended)
API_KEY = os.getenv("VIRUSTOTAL_API_KEY")  # DO NOT hardcode in code you share
if not API_KEY:
    raise SystemExit("ERROR: Set environment variable VIRUSTOTAL_API_KEY or use a .env file.")

# 2) I/O paths
OUTPUT_DIR = "./output"
RAW_DIR = os.path.join(OUTPUT_DIR, "raw_responses")
INPUT_CSV = "./output/hash_signature_output.csv"   # change to your CSV path if needed

# 3) VT API base and headers
VT_BASE = "https://www.virustotal.com/api/v3"
HEADERS = {"x-apikey": API_KEY}

# 4) Rate limits: VT free (public) API ≈ 4 req/min ⇒ 15 s between requests; stay conservative.
DELAY_BETWEEN_REQUESTS = 16.0

# 5) Assignment targets (Grad)
TARGET_MIN_SAMPLES = 78    # minimum total samples (assignment requires ≥ 75)
TARGET_MAX_SAMPLES = 100   # cap to avoid over-collection
MIN_FAMILIES = 8           # minimum malware families (assignment requires ≥ 6)

# 6) Enforce minimums strictly (exit early if unmet)
HARD_REQUIRE_MINIMUMS = True

# Logging (INFO is good; use DEBUG for even more detail)
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")

# Ensure output folders exist
os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs(RAW_DIR, exist_ok=True)

print("✅ Config ready")


✅ Config ready
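The 16 s delay comes from the per-minute quota. A minimal sketch of the budget arithmetic, assuming public-API limits of 4 requests/minute and 500 requests/day (check the quota shown for your own key) and two calls per hash; `plan_budget` is a hypothetical helper, not part of the pipeline:

```python
import math


def plan_budget(n_samples: int,
                calls_per_sample: int = 2,
                requests_per_minute: float = 4.0,
                daily_quota: int = 500,
                safety_margin_s: float = 1.0) -> dict:
    """Estimate call spacing, call volume and minimum sleep time for one collection run.

    Quota figures are assumptions (public-API defaults); adjust them for your key.
    """
    # Evenly spacing calls under an N-per-minute cap needs 60/N seconds; add a margin.
    delay_s = 60.0 / requests_per_minute + safety_margin_s
    total_calls = n_samples * calls_per_sample  # file report + behaviour summary
    return {
        "delay_s": delay_s,
        "total_calls": total_calls,
        "fits_daily_quota": total_calls <= daily_quota,
        "days_needed": math.ceil(total_calls / daily_quota),
        # One sleep between consecutive calls: (calls - 1) gaps.
        "min_runtime_min": (total_calls - 1) * delay_s / 60.0,
    }


print(plan_budget(100))
```

For 100 samples this gives a 16.0 s spacing, 200 calls (within one day's quota) and about 53 minutes of sleeping, in line with the planned-sleep estimate `main()` prints. The +1 s margin is an arbitrary safety buffer.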


🧩 Cell 4 - HTTP Session with Retries

In [4]:
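# Create a single requests.Session with retry policy:
# - Retries on 429 (rate limit) and common 5xx errors
# - Exponential backoff between attempts (urllib3 also honours a Retry-After header on 429)
# Note: vt_get() (Cell 7) adds its own backoff loop on top of this adapter-level retry.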
session = requests.Session()
retries = Retry(
    total=5,
    backoff_factor=1,
    status_forcelist=[429, 500, 502, 503, 504],
    allowed_methods=["GET"]
)
session.mount("https://", HTTPAdapter(max_retries=retries))

print("✅ HTTP session ready with retries")


✅ HTTP session ready with retries
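This adapter policy stacks with the manual loop in `vt_get()` (Cell 7). With the default `raise_on_status=True`, once the adapter exhausts its retries `requests` raises `requests.exceptions.RetryError`, which `vt_get()` catches as a `RequestException` and retries again. A small worked sketch of the worst case, assuming `total=5` here and `max_retries=4` in `vt_get()`; both helper names are hypothetical:

```python
from typing import List


def vt_get_sleep_schedule(max_retries: int = 4, initial_backoff: float = 1.0) -> List[float]:
    """Sleeps vt_get() performs if every attempt fails (it sleeps after each failure)."""
    schedule, backoff = [], initial_backoff
    for _ in range(max_retries + 1):  # attempts 0..max_retries
        schedule.append(backoff)
        backoff *= 2
    return schedule


def worst_case_http_requests(session_total: int = 5, helper_max_retries: int = 4) -> int:
    """Each session.get() can send 1 + session_total requests, and vt_get()
    calls session.get() up to helper_max_retries + 1 times."""
    return (session_total + 1) * (helper_max_retries + 1)


sleeps = vt_get_sleep_schedule()
print(sleeps, sum(sleeps), worst_case_http_requests())
```

So a single persistently failing hash can trigger up to 30 HTTP requests plus 31 s of helper-level sleeps (on top of the adapter's own backoff), which counts against a 4-requests/minute quota. Keeping only one retry layer may be easier to reason about.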


🧩 Cell 5 - CSV Reader & Basic Validations

In [5]:
def read_hash_csv(csv_path: str, hash_col: str = "hash") -> pd.DataFrame:
    """
    Read hashes + metadata from a CSV.
    Required column: `hash_col` (default 'hash'); it is renamed to 'hash'
    Optional columns: 'malware_family', 'source'
    - Deduplicates by 'hash'
    - Lowercases and trims hash strings
    """
    df = pd.read_csv(csv_path, dtype=str).fillna("")
    if hash_col not in df.columns:
        raise ValueError(f"CSV must contain a '{hash_col}' column. Found: {list(df.columns)}")
    # The rest of the pipeline expects the column to be called 'hash'
    df = df.rename(columns={hash_col: 'hash'})
    df['hash'] = df['hash'].str.strip().str.lower()

    if 'malware_family' not in df.columns:
        df['malware_family'] = "Unknown"
    if 'source' not in df.columns:
        df['source'] = "Unknown"

    df = df[df['hash'] != ""].drop_duplicates(subset=['hash'])
    return df


def assert_dataset_meets_requirements(df: pd.DataFrame,
                                      min_samples: int,
                                      min_families: int) -> None:
    """
    Hard stop (or warn) if the dataset does not meet min counts.
    - Unique hashes ≥ min_samples
    - Family count ≥ min_families
    """
    num_hashes = df['hash'].nunique()
    num_fams = df['malware_family'].replace("", "Unknown").nunique()

    problems = []
    if num_hashes < min_samples:
        problems.append(f"- Only {num_hashes} unique hashes found (need ≥ {min_samples}).")
    if num_fams < min_families:
        problems.append(f"- Only {num_fams} malware families found (need ≥ {min_families}).")

    if problems and HARD_REQUIRE_MINIMUMS:
        msg = (
            "\nINPUT DATA DOES NOT MEET MINIMUMS:\n"
            + "\n".join(problems)
            + "\n\nTips:\n"
              "  • Add hashes from MalwareBazaar / Malpedia / ANY.RUN.\n"
              "  • Ensure 'hash' is valid and 'malware_family' is labeled.\n"
              "  • Remove duplicates (one row per unique hash).\n"
        )
        raise SystemExit(msg)
    elif problems:
        logging.warning("Proceeding despite dataset not meeting minimums:\n" + "\n".join(problems))

print("✅ CSV helpers ready")


✅ CSV helpers ready
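`read_hash_csv()` trims and lowercases hashes but does not check their format, so a malformed entry still costs API calls. A minimal pre-filter sketch, assuming the input holds MD5, SHA-1 or SHA-256 hex digests (the file identifiers VirusTotal's `/files/{id}` endpoint accepts); `classify_hash` and `drop_malformed_hashes` are hypothetical helpers:

```python
import re
from typing import Optional

import pandas as pd

# Hex digest length -> algorithm name.
_HASH_LENGTHS = {32: "md5", 40: "sha1", 64: "sha256"}
_HEX_RE = re.compile(r"[0-9a-f]+")


def classify_hash(value) -> Optional[str]:
    """Return 'md5', 'sha1' or 'sha256' for a well-formed hex digest, else None."""
    if not isinstance(value, str):
        return None
    v = value.strip().lower()
    if not _HEX_RE.fullmatch(v):
        return None
    return _HASH_LENGTHS.get(len(v))


def drop_malformed_hashes(df: pd.DataFrame, hash_col: str = "hash") -> pd.DataFrame:
    """Keep only rows whose hash is a recognised digest, recording its type."""
    out = df.copy()
    out["hash_type"] = out[hash_col].map(classify_hash)
    return out[out["hash_type"].notna()].reset_index(drop=True)
```

Running it right after `read_hash_csv()` (and before the minimum-count checks) means the sample and family counts reflect only hashes VirusTotal can actually look up.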


🧩 Cell 6 - Balanced, Round-Robin Sampling (≥ 75 samples & ≥ 6 families)

In [6]:
def select_sample_hashes(df: pd.DataFrame,
                         min_samples: int = TARGET_MIN_SAMPLES,
                         max_samples: int = TARGET_MAX_SAMPLES,
                         min_families: int = MIN_FAMILIES) -> pd.DataFrame:
    """
    Build a balanced selection by iterating families in round-robin until we hit the target.
    Ensures (best effort) ≥ min_families in the final selection.
    """
    df = df.copy()
    df['malware_family'] = df['malware_family'].replace("", "Unknown")

    # Enforce minimums on the whole dataset
    assert_dataset_meets_requirements(df, min_samples=min_samples, min_families=min_families)

    # Families ordered by frequency (more populous first)
    families_ordered = df['malware_family'].value_counts().index.tolist()
    logging.info(f"Available families in input: {len(families_ordered)}")

    # Target amount: as many as max_samples allows, limited by dataset size.
    # (min_samples is checked by assert_dataset_meets_requirements() above.)
    target = min(len(df), max_samples)
    logging.info(f"Target sample count: {target}")

    # Family pools shuffled deterministically for reproducibility
    pools: Dict[str, List[Dict[str, Any]]] = {}
    for fam in families_ordered:
        pools[fam] = df[df['malware_family'] == fam].sample(frac=1.0, random_state=42).to_dict(orient='records')

    # Round-robin pick across families
    selected: List[Dict[str, Any]] = []
    families_rr = families_ordered[:]
    pos = 0
    while len(selected) < target and families_rr:
        pos %= len(families_rr)
        fam = families_rr[pos]
        if pools[fam]:
            selected.append(pools[fam].pop())
            pos += 1
        else:
            # Remove the empty family; the next family slides into `pos`
            families_rr.pop(pos)

    sampled_df = pd.DataFrame(selected)

    # Ensure ≥ min_families in the final selection (swap in if needed)
    sel_fams = sampled_df['malware_family'].nunique()
    if sel_fams < min_families:
        logging.warning(f"Selection has only {sel_fams} families; trying to add more...")
        remaining_rows = df[~df['hash'].isin(sampled_df['hash'])]
        extras = (remaining_rows.groupby('malware_family').head(1).reset_index(drop=True))
        for _, row in extras.iterrows():
            if row['malware_family'] not in sampled_df['malware_family'].unique():
                if len(sampled_df) < target:
                    sampled_df = pd.concat([sampled_df, pd.DataFrame([row])], ignore_index=True)
                else:
                    # Replace one row from most common family to keep total unchanged
                    over_fam = sampled_df['malware_family'].value_counts().index[0]
                    idx_drop = sampled_df[sampled_df['malware_family'] == over_fam].index[0]
                    sampled_df = sampled_df.drop(index=idx_drop)
                    sampled_df = pd.concat([sampled_df, pd.DataFrame([row])], ignore_index=True)
            if sampled_df['malware_family'].nunique() >= min_families:
                break

    logging.info(f"Selected {len(sampled_df)} samples from {sampled_df['malware_family'].nunique()} families")
    return sampled_df.reset_index(drop=True)

print("✅ Sampler ready")


✅ Sampler ready
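Stripped of pandas, the round-robin selection in `select_sample_hashes()` reduces to a queue of per-family pools. A minimal sketch on toy data (family names and items are placeholders; `round_robin` is a hypothetical helper): each turn takes one item from the family at the head of the queue and re-queues that family only while it still has items, so small families drop out early and larger ones fill the remaining slots.

```python
from collections import deque
from typing import Dict, List, Tuple


def round_robin(pools: Dict[str, List[str]], target: int) -> List[Tuple[str, str]]:
    """Pick up to `target` (family, item) pairs, one family at a time in rotation."""
    queue = deque((fam, deque(items)) for fam, items in pools.items() if items)
    picked: List[Tuple[str, str]] = []
    while queue and len(picked) < target:
        fam, items = queue.popleft()
        picked.append((fam, items.popleft()))
        if items:  # re-queue the family only while it still has items
            queue.append((fam, items))
    return picked


toy = {"FamA": ["a1", "a2", "a3"], "FamB": ["b1"], "FamC": ["c1", "c2"]}
print(round_robin(toy, target=5))
```

Using a `deque` avoids the index bookkeeping needed when families are removed from a list mid-rotation.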


🧩 Cell 7 - VirusTotal GET Helper (with backoff)

In [7]:
def vt_get(endpoint: str, params: Optional[Dict] = None, max_retries: int = 4) -> Dict[str, Any]:
    """
    GET helper with exponential backoff on 429/5xx and network errors.
    Returns: {'success': bool, 'status_code': int|None, 'json': dict|None, 'error': str|None}
    Note: the session adapter (Cell 4) already retries 429/5xx; once it gives up, requests
    raises RetryError (a RequestException), which the except branch below backs off on.
    """
    url = f"{VT_BASE}{endpoint}"
    attempt = 0
    backoff = 1.0

    while attempt <= max_retries:
        try:
            resp = session.get(url, headers=HEADERS, params=params, timeout=30)
            if resp.status_code == 200:
                return {'success': True, 'status_code': 200, 'json': resp.json(), 'error': None}
            if resp.status_code == 404:
                return {'success': False, 'status_code': 404, 'json': None, 'error': 'Not found'}
            if resp.status_code == 429 or 500 <= resp.status_code < 600:
                logging.warning(f"HTTP {resp.status_code}; sleeping {backoff:.1f}s (attempt {attempt})")
            else:
                # Other non-success codes (e.g. 401 for a bad API key) are not retried
                try:
                    j = resp.json()
                except ValueError:
                    j = None
                return {'success': False, 'status_code': resp.status_code, 'json': j,
                        'error': f"HTTP {resp.status_code}"}
        except requests.RequestException as e:
            logging.warning(f"Request error: {e}; sleeping {backoff:.1f}s (attempt {attempt})")

        # Retryable failure: wait, then double the wait for the next attempt
        time.sleep(backoff)
        backoff *= 2
        attempt += 1

    return {'success': False, 'status_code': None, 'json': None, 'error': 'Max retries exceeded'}


def get_file_report(hash_val: str) -> Dict[str, Any]:
    """VT file metadata (detections, timestamps, etc.)."""
    return vt_get(f"/files/{hash_val}")


def get_behavior_report(hash_val: str) -> Dict[str, Any]:
    """VT sandbox behaviour summary when available."""
    return vt_get(f"/files/{hash_val}/behaviour_summary")

print("✅ VT API helpers ready")


✅ VT API helpers ready
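The backoff policy in `vt_get()` can be exercised offline if the HTTP call and the sleep are passed in as functions. A minimal sketch, assuming a `fetch` callable that returns `(status_code, payload)`; `get_with_backoff` is a hypothetical name, and network exceptions are left out for brevity:

```python
from typing import Any, Callable, Dict, Optional, Tuple

Fetch = Callable[[], Tuple[int, Optional[dict]]]


def get_with_backoff(fetch: Fetch,
                     sleep: Callable[[float], None],
                     max_retries: int = 4,
                     initial_backoff: float = 1.0) -> Dict[str, Any]:
    """Same status-code policy as vt_get(), with I/O injected so it can be tested offline."""
    backoff = initial_backoff
    for _ in range(max_retries + 1):
        status, payload = fetch()
        if status == 200:
            return {"success": True, "status_code": 200, "json": payload, "error": None}
        if status == 404:
            return {"success": False, "status_code": 404, "json": None, "error": "Not found"}
        if status == 429 or 500 <= status < 600:
            sleep(backoff)  # transient failure: wait, then double the next wait
            backoff *= 2
            continue
        return {"success": False, "status_code": status, "json": payload, "error": f"HTTP {status}"}
    return {"success": False, "status_code": None, "json": None, "error": "Max retries exceeded"}


# Usage: a 429, then a 503, then success; record the sleeps instead of waiting.
responses = iter([(429, None), (503, None), (200, {"data": {}})])
slept = []
result = get_with_backoff(lambda: next(responses), slept.append)
print(result["success"], slept)  # True [1.0, 2.0]
```

Injecting `sleep` keeps tests instant while exercising exactly the waits a real run would take.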


🧩 Cell 8 - Safe Behaviour Parsing (null-safe)

In [8]:
def extract_behavioral_indicators(behav_resp: Dict[str, Any]) -> Dict[str, Any]:
    """
    Parse /behaviour_summary response safely into flattened indicators.
    Returns empty lists when fields are missing (some hashes have no behaviour).
    """
    indicators = {
        'processes_created': [],
        'files_written': [],
        'files_deleted': [],
        'registry_keys_set': [],
        'registry_keys_deleted': [],
        'dns_lookups': [],
        'ip_traffic': [],
        'http_conversations': [],
        'command_executions': [],
        'mutexes_created': [],
        'services_created': [],
        'mitre_techniques': []
    }

    # If HTTP call failed or no JSON, return empty indicators
    if not behav_resp or not behav_resp.get('success'):
        return indicators

    try:
        j = behav_resp.get('json') or {}
        if not isinstance(j, dict):
            return indicators

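        # /behaviour_summary usually returns the summary fields directly under json['data'];
        # also accept a nested 'data' / 'attributes' wrapper in case the shape differs.
        attrs = j.get('data') or {}
        if isinstance(attrs, dict) and isinstance(attrs.get('data'), dict):
            attrs = attrs['data']
        if isinstance(attrs, dict) and isinstance(attrs.get('attributes'), dict):
            attrs = attrs['attributes']
        if not isinstance(attrs, dict):
            return indicators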

        # Copy simple lists
        indicators['processes_created'] = attrs.get('processes_created') or []
        indicators['files_written']     = attrs.get('files_written')     or []
        indicators['files_deleted']     = attrs.get('files_deleted')     or []
        indicators['registry_keys_set'] = attrs.get('registry_keys_set') or []
        indicators['registry_keys_deleted'] = attrs.get('registry_keys_deleted') or []
        indicators['command_executions']    = attrs.get('command_executions')    or []
        indicators['mutexes_created']       = attrs.get('mutexes_created')       or []
        indicators['services_created']      = attrs.get('services_created')      or []

        # Normalize structured items
        dns_lookups = attrs.get('dns_lookups') or []
        if isinstance(dns_lookups, list):
            indicators['dns_lookups'] = [d.get('hostname', '') for d in dns_lookups if isinstance(d, dict)]

        ip_traffic = attrs.get('ip_traffic') or []
        if isinstance(ip_traffic, list):
            indicators['ip_traffic'] = [
                f"{ip.get('destination_ip','')}:{ip.get('destination_port','')}"
                for ip in ip_traffic if isinstance(ip, dict)
            ]

        http_conversations = attrs.get('http_conversations') or []
        if isinstance(http_conversations, list):
            indicators['http_conversations'] = [h.get('url','') for h in http_conversations if isinstance(h, dict)]

        mt = attrs.get('mitre_attack_techniques') or []
        if isinstance(mt, list):
            indicators['mitre_techniques'] = [t for t in mt if isinstance(t, dict)]

    except Exception as e:
        logging.error(f"Error parsing behavior response safely: {e}")

    return indicators

print("✅ Behaviour parser ready")


✅ Behaviour parser ready
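A quick way to sanity-check the parsing is to run a hand-written payload through the same kind of field handling. The sketch below assumes the `/behaviour_summary` response places summary fields directly under `data` (compare with a saved `*_behavior.json` in `output/raw_responses/` before relying on it); the payload values are made up, and `summarise_behaviour` is a hypothetical helper covering only three fields:

```python
from typing import Any, Dict, List


def summarise_behaviour(payload: Dict[str, Any]) -> Dict[str, List[str]]:
    """Flatten a few behaviour-summary fields; missing or malformed fields become []."""
    data = payload.get("data") if isinstance(payload, dict) else None
    if not isinstance(data, dict):
        data = {}

    def dict_items(key: str) -> List[dict]:
        value = data.get(key) or []
        return [item for item in value if isinstance(item, dict)] if isinstance(value, list) else []

    return {
        "dns_lookups": [d.get("hostname", "") for d in dict_items("dns_lookups")],
        "ip_traffic": [f"{d.get('destination_ip', '')}:{d.get('destination_port', '')}"
                       for d in dict_items("ip_traffic")],
        # Sorted, de-duplicated technique IDs (in case an ID is listed more than once)
        "technique_ids": sorted({t["id"] for t in dict_items("mitre_attack_techniques") if t.get("id")}),
    }


example = {  # illustrative values only
    "data": {
        "dns_lookups": [{"hostname": "example.com"}],
        "ip_traffic": [{"destination_ip": "192.0.2.10", "destination_port": 443}],
        "mitre_attack_techniques": [{"id": "T1082"}, {"id": "T1055"}, {"id": "T1082"}],
    }
}
print(summarise_behaviour(example))
```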


🧩 Cell 9 - Build ATT&CK Navigator Layer

In [9]:
def build_attck_layer(tech_counter: Counter) -> Dict[str, Any]:
    """
    Build a minimal ATT&CK Navigator layer (layer-format 4.x field names).
    - technique.score = value from tech_counter (built in collect_samples)
    - no 'tactic' key: behaviour_summary does not report tactics, and an entry without
      'tactic' applies to the technique under every tactic it belongs to
    """
    techniques = [
        {"techniqueID": tid, "score": int(count)}
        for tid, count in sorted(tech_counter.items())
    ]
    return {
        "name": "VT → MITRE ATT&CK Layer",
        "description": "Techniques extracted from VirusTotal behaviour summaries",
        "domain": "enterprise-attack",   # Enterprise ATT&CK domain identifier
        "versions": {"layer": "4.5"},    # adjust to the layer version your Navigator expects
        "techniques": techniques
    }

print("✅ ATT&CK layer builder ready")


✅ ATT&CK layer builder ready
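Without a `gradient`, Navigator applies its default score range, so counts far below the top of that range may all look alike. A sketch of a layer whose colour scale spans 0 to the largest score, assuming layer-format 4.x fields (`domain`, `versions`, `techniques`, `gradient` with `colors`/`minValue`/`maxValue`); confirm them against the layer-format specification for your Navigator release. `build_layer_with_gradient` is a hypothetical variant of `build_attck_layer()`, and the colours are arbitrary:

```python
import json
from collections import Counter
from typing import Any, Dict


def build_layer_with_gradient(tech_counter: Counter, name: str = "VT techniques") -> Dict[str, Any]:
    """ATT&CK Navigator layer whose colour scale spans 0..max(score)."""
    max_score = max(tech_counter.values(), default=0)
    return {
        "name": name,
        "domain": "enterprise-attack",
        "versions": {"layer": "4.5"},  # assumed layer-format version; adjust as needed
        "description": "Score = technique count from VirusTotal behaviour summaries",
        "techniques": [
            # No "tactic" key, so each entry applies to the technique under all its tactics
            {"techniqueID": tid, "score": int(n), "comment": f"count: {n}"}
            for tid, n in sorted(tech_counter.items())
        ],
        "gradient": {
            "colors": ["#ffffff", "#ff6666"],
            "minValue": 0,
            "maxValue": max(max_score, 1),  # avoid a zero-width range for an empty counter
        },
    }


layer = build_layer_with_gradient(Counter({"T1082": 40, "T1055": 12}))
print(json.dumps(layer, indent=2))
```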


🧩 Cell 10 - Collection Loop (rate-limited)

In [10]:
def collect_samples(sample_df: pd.DataFrame,
                    delay_between_requests: float = DELAY_BETWEEN_REQUESTS,
                    save_every: int = 10) -> List[Dict[str, Any]]:
    """
    For each hash:
      1) GET file report (save raw)
      2) sleep
      3) GET behaviour summary (save raw)
      4) parse indicators â†’ row
      5) rate-limit sleeps between samples
    Saves partial results every `save_every` rows.
    """
    results = []
    tech_counter = Counter()

    for idx, row in enumerate(tqdm(sample_df.to_dict(orient="records"), desc="Collecting samples"), start=1):
        h = row['hash']
        family = row.get('malware_family', 'Unknown')
        source = row.get('source', 'Unknown')

        # Initialize result row with defaults
        result = {
            'hash': h,
            'family': family,
            'source': source,
            'status': 'failed',
            'detection_ratio': None,
            'first_seen': None,
            'last_seen': None,
            'processes_count': 0,
            'files_written_count': 0,
            'files_deleted_count': 0,
            'registry_keys_set_count': 0,
            'dns_lookups_count': 0,
            'ip_connections_count': 0,
            'http_requests_count': 0,
            'mutexes_count': 0,
            'mitre_techniques_count': 0,
            'mitre_techniques': '',
            'collected_date': datetime.now(timezone.utc).isoformat(),
        }

        # ---- 1) File report
        fr = get_file_report(h)
        with open(os.path.join(RAW_DIR, f"{h}_file.json"), "w") as fh:
            json.dump(fr, fh, indent=2, default=str)

        if not fr.get('success'):
            # Record failure and continue (still sleep to respect rate limits)
            result['error'] = fr.get('error', 'file_report_failed')
            results.append(result)
            time.sleep(delay_between_requests)
            continue

        # Parse some fields (safe access)
        try:
            attrs = fr['json'].get('data', {}).get('attributes', {}) if fr.get('json') else {}
            stats = attrs.get('last_analysis_stats', {})
            total = sum(stats.values()) if isinstance(stats, dict) else 0
            malicious = stats.get('malicious', 0) if isinstance(stats, dict) else 0
            result['detection_ratio'] = f"{malicious}/{total}" if total else None
            result['first_seen'] = attrs.get('first_submission_date') or None
            result['last_seen'] = attrs.get('last_analysis_date') or None
        except Exception as e:
            logging.debug(f"Error extracting stats for {h}: {e}")

        # ---- rate-limit sleep between major API calls
        time.sleep(delay_between_requests)

        # ---- 2) Behaviour summary
        br = get_behavior_report(h)
        with open(os.path.join(RAW_DIR, f"{h}_behavior.json"), "w") as fh:
            json.dump(br, fh, indent=2, default=str)

        # Parse behaviour; extract_behavioral_indicators() returns empty lists if the call failed
        indicators = extract_behavioral_indicators(br)

        # Unique technique IDs for this sample (an ID may be listed more than once);
        # counting each ID once per sample makes the Navigator score a per-sample count
        tech_ids = sorted({t['id'] for t in indicators['mitre_techniques']
                           if isinstance(t, dict) and t.get('id')})

        # Update counts and technique list
        result.update({
            'processes_count': len(indicators['processes_created']),
            'files_written_count': len(indicators['files_written']),
            'files_deleted_count': len(indicators['files_deleted']),
            'registry_keys_set_count': len(indicators['registry_keys_set']),
            'dns_lookups_count': len(indicators['dns_lookups']),
            'ip_connections_count': len(indicators['ip_traffic']),
            'http_requests_count': len(indicators['http_conversations']),
            'mutexes_count': len(indicators['mutexes_created']),
            'mitre_techniques_count': len(tech_ids),
            'mitre_techniques': ", ".join(tech_ids),
        })

        # Count techniques globally (for Navigator layer)
        tech_counter.update(tech_ids)

        # Mark row success and append
        result['status'] = 'success'
        results.append(result)

        # Periodic save of partial progress
        if idx % save_every == 0 or idx == len(sample_df):
            df_partial = pd.DataFrame(results)
            df_partial.to_csv(os.path.join(OUTPUT_DIR, "analysis_results_partial.csv"), index=False)
            with open(os.path.join(OUTPUT_DIR, "analysis_results_partial.json"), "w") as fh:
                json.dump(results, fh, indent=2)
            logging.info(f"Saved partial results at {idx}/{len(sample_df)}")

        # ---- rate-limit sleep between samples
        if idx < len(sample_df):
            time.sleep(delay_between_requests)

    # Save final, full outputs
    df_final = pd.DataFrame(results)
    df_final.to_csv(os.path.join(OUTPUT_DIR, "analysis_results.csv"), index=False)
    df_final.to_json(os.path.join(OUTPUT_DIR, "analysis_results.json"), orient="records", indent=2)

    # Build ATT&CK Navigator layer
    layer_json = build_attck_layer(tech_counter)
    with open(os.path.join(OUTPUT_DIR, "attck_navigator_layer.json"), "w") as fh:
        json.dump(layer_json, fh, indent=2)

    logging.info(f"Collection complete. Outputs saved to {OUTPUT_DIR}")
    return results

print("✅ Collector ready")


✅ Collector ready
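`first_seen` and `last_seen` are stored exactly as VirusTotal returns them, as Unix timestamps in seconds (UTC). A sketch for converting them to ISO-8601 so the CSV is easier to read, assuming that epoch-seconds format; `vt_epoch_to_iso` is a hypothetical helper:

```python
from datetime import datetime, timezone
from typing import Any, Optional


def vt_epoch_to_iso(ts: Any) -> Optional[str]:
    """Convert a Unix timestamp (seconds since epoch, UTC) to ISO-8601.

    Returns None for missing, non-numeric, or non-positive values (0 is treated as unset).
    """
    try:
        seconds = int(ts)
    except (TypeError, ValueError, OverflowError):  # None, NaN, "", non-numeric text
        return None
    if seconds <= 0:
        return None
    try:
        return datetime.fromtimestamp(seconds, tz=timezone.utc).isoformat()
    except (OverflowError, OSError, ValueError):  # outside the platform's supported range
        return None


print(vt_epoch_to_iso(1700000000))
```

For example, `df_final['first_seen'] = df_final['first_seen'].map(vt_epoch_to_iso)` before writing the CSV; NaN values from failed rows map to None.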


🧩 Cell 11 - Main Driver (select + estimate + run)

In [11]:
def main(input_csv: str = INPUT_CSV,
         min_samples: int = TARGET_MIN_SAMPLES,
         max_samples: int = TARGET_MAX_SAMPLES,
         min_families: int = MIN_FAMILIES):
    # 1) Read input hashes
    logging.info("Reading input CSV...")
    df = read_hash_csv(input_csv)
    if df.empty:
        raise SystemExit("No hashes found in input CSV. Exiting.")

    # 2) Enforce assignment minimums
    assert_dataset_meets_requirements(df, min_samples=min_samples, min_families=min_families)

    # 3) Build a balanced selection
    samples = select_sample_hashes(df, min_samples=min_samples, max_samples=max_samples, min_families=min_families)
    logging.info(f"Beginning collection of {len(samples)} samples (delay {DELAY_BETWEEN_REQUESTS}s between major API calls)")

    # 4) Print planned sleep time (excludes network/IO/retries)
    n = len(samples)
    planned_sleeps = (2 * n) - 1
    planned_wait_seconds = planned_sleeps * DELAY_BETWEEN_REQUESTS
    print(f"Planned sleeps: {planned_sleeps} × {DELAY_BETWEEN_REQUESTS:.1f}s "
          f"= {int(planned_wait_seconds)}s ({planned_wait_seconds/60:.1f} minutes) "
          "(excludes network/retries/IO).")

    # 5) Confirm run (set CONFIRM=yes to skip)
    if os.getenv("CONFIRM", "no").lower() != "yes":
        print(f"About to query VirusTotal for {len(samples)} hashes. This respects rate-limits "
              f"({DELAY_BETWEEN_REQUESTS}s per major call).")
        proceed = input("Proceed? (yes/no): ").strip().lower()
        if proceed != "yes":
            logging.info("User aborted.")
            return

    # 6) Collect + save outputs
    results = collect_samples(samples)

    # 7) Short summary
    df_res = pd.DataFrame(results)
    success_count = (df_res['status'] == 'success').sum()
    logging.info(f"Finished. Success: {success_count}/{len(df_res)}")
    logging.info("Open output/attck_navigator_layer.json in ATT&CK Navigator to visualize technique coverage.")

print("✅ Main ready")


✅ Main ready
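A full run takes close to an hour, so restarting from scratch after an interruption is costly. A minimal resume sketch, assuming the `{hash}_file.json` / `{hash}_behavior.json` naming that `collect_samples()` uses and treating a hash as done once both raw files exist; `pending_hashes` is a hypothetical helper:

```python
import os
import tempfile
from typing import Iterable, List


def pending_hashes(hashes: Iterable[str], raw_dir: str) -> List[str]:
    """Hashes that still lack a saved file report or behaviour summary in raw_dir."""
    todo = []
    for h in hashes:
        file_json = os.path.join(raw_dir, f"{h}_file.json")
        behav_json = os.path.join(raw_dir, f"{h}_behavior.json")
        if not (os.path.exists(file_json) and os.path.exists(behav_json)):
            todo.append(h)
    return todo


# Demo in a throwaway directory: "aaa" is complete, "bbb" is half done, "ccc" untouched.
with tempfile.TemporaryDirectory() as demo_dir:
    for name in ("aaa_file.json", "aaa_behavior.json", "bbb_file.json"):
        open(os.path.join(demo_dir, name), "w").close()
    print(pending_hashes(["aaa", "bbb", "ccc"], demo_dir))
```

In `main()` this could filter `samples` with `samples[samples['hash'].isin(pending_hashes(samples['hash'], RAW_DIR))]` before calling `collect_samples()`. The final CSV would then hold only the resumed rows, so merge it with the earlier partial output.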


🧩 Cell 12 - Run

In [12]:
# Run the main flow. You can override INPUT_CSV or min/max here if needed.
main(
    input_csv=INPUT_CSV,
    min_samples=TARGET_MIN_SAMPLES,
    max_samples=TARGET_MAX_SAMPLES,
    min_families=MIN_FAMILIES
)


2025-11-03 23:17:51,769 - INFO - Reading input CSV...
2025-11-03 23:17:51,792 - INFO - Available families in input: 23
2025-11-03 23:17:51,793 - INFO - Target sample count: 100
2025-11-03 23:17:51,808 - INFO - Selected 100 samples from 23 families
2025-11-03 23:17:51,809 - INFO - Beginning collection of 100 samples (delay 16.0s between major API calls)


Planned sleeps: 199 × 16.0s = 3184s (53.1 minutes) (excludes network/retries/IO).
About to query VirusTotal for 100 hashes. This respects rate-limits (16.0s per major call).


Collecting samples:   9%|▉         | 9/100 [04:54<49:38, 32.73s/it]
2025-11-03 23:23:08,364 - INFO - Saved partial results at 10/100
Collecting samples:  19%|█▉        | 19/100 [10:23<44:30, 32.97s/it]
2025-11-03 23:28:36,974 - INFO - Saved partial results at 20/100
Collecting samples:  29%|██▉       | 29/100 [16:04<41:13, 34.83s/it]
2025-11-03 23:34:17,631 - INFO - Saved partial results at 30/100
Collecting samples:  39%|███▉      | 39/100 [21:32<33:23, 32.84s/it]
2025-11-03 23:39:45,719 - INFO - Saved partial results at 40/100
Collecting samples:  49%|████▉     | 49/100 [27:01<28:14, 33.23s/it]
2025-11-03 23:45:15,473 - INFO - Saved partial results at 50/100
Collecting samples:  59%|█████▉    | 59/100 [32:28<22:15, 32.58s/it]
2025-11-03 23:50:41,543 - INFO - Saved partial results at 60/100
Collecting samples:  69%|██████▉   | 69/100 [40:02<19:13, 37.20s/it]
2025-11-03 23:58:16,674 - INFO - Saved partial results at 70/100
Collecting sam