# Offline Behaviour Fetcher
Load the already-saved VirusTotal responses from `output/raw_responses` instead of calling the API.

In [9]:
# Core / stdlib
import os
import json
import logging
from typing import List, Dict, Any
from datetime import datetime, timezone
from collections import Counter

# Third-party libs
import pandas as pd
from tqdm import tqdm

print("✅ Imports ready (offline cache mode)")


✅ Imports ready (offline cache mode)


In [4]:
# ---- OFFLINE CONFIG ----
# Module-level settings for the offline pipeline. Running this cell has side
# effects: it configures root logging, creates OUTPUT_DIR, and aborts the run
# with SystemExit when the cached-response directory does not exist.

OUTPUT_DIR = "./output"
# Directory holding the cached VirusTotal responses (<hash>_<suffix>.json).
RAW_DIR = os.path.join(OUTPUT_DIR, "raw_responses")
INPUT_CSV = "./output/hash_signature_output.csv"

TARGET_MIN_SAMPLES = 78    # minimum unique hashes the input CSV must contain
TARGET_MAX_SAMPLES = 100   # upper bound on how many samples are selected
MIN_FAMILIES = 8           # families required in the input (best effort in the selection)

# True: inputs below the minimums abort with SystemExit; False: only log a warning.
HARD_REQUIRE_MINIMUMS = True

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")

os.makedirs(OUTPUT_DIR, exist_ok=True)
# Offline mode cannot work without the cache, so fail fast here.
if not os.path.isdir(RAW_DIR):
    raise SystemExit(f"Expected cached responses in {RAW_DIR}.")

print("✅ Config ready (reading cached responses)")


✅ Config ready (reading cached responses)


In [7]:
def read_hash_csv(csv_path: str, hash_col: str = "hash") -> pd.DataFrame:
    """
    Read hashes + metadata from a CSV.

    Args:
        csv_path: Path to the input CSV. Every column is read as a string and
            missing cells become "".
        hash_col: Name of the column holding the sample hashes (default 'hash').
            Its normalised values are always exposed as a 'hash' column because
            the rest of the pipeline looks hashes up under that name.

    Returns:
        DataFrame with a normalised 'hash' column (trimmed + lowercased),
        guaranteed 'malware_family' and 'source' columns ("Unknown" when the
        CSV lacks them), empty hashes dropped and duplicate hashes removed
        (first occurrence kept).

    Raises:
        ValueError: if `hash_col` is not a column of the CSV.
    """
    df = pd.read_csv(csv_path, dtype=str).fillna("")
    if hash_col not in df.columns:
        raise ValueError(f"CSV must contain a '{hash_col}' column. Found: {list(df.columns)}")

    # Read from the configured column (a hard-coded 'hash' lookup would raise
    # KeyError for any other hash_col) and publish the result as 'hash'.
    df['hash'] = df[hash_col].str.strip().str.lower()

    # Optional metadata columns default to "Unknown" when absent entirely.
    for optional_col in ('malware_family', 'source'):
        if optional_col not in df.columns:
            df[optional_col] = "Unknown"

    df = df[df['hash'] != ""].drop_duplicates(subset=['hash'])
    return df


def assert_dataset_meets_requirements(df: pd.DataFrame,
                                      min_samples: int,
                                      min_families: int,
                                      *,
                                      hard_require: "bool | None" = None) -> None:
    """
    Stop (or warn) when the dataset does not meet the minimum counts.

    Checks:
      - unique values in 'hash' >= min_samples
      - unique values in 'malware_family' >= min_families ("" counts as "Unknown")

    Args:
        df: Frame with 'hash' and 'malware_family' columns.
        min_samples: Required number of unique hashes.
        min_families: Required number of distinct families.
        hard_require: True raises, False only warns. None (default) falls back
            to the module-level HARD_REQUIRE_MINIMUMS setting.

    Raises:
        SystemExit: when a minimum is missed and hard requirements are enabled.
    """
    if hard_require is None:
        # Looked up lazily so callers passing the flag never need the global.
        hard_require = HARD_REQUIRE_MINIMUMS

    num_hashes = df['hash'].nunique()
    num_fams = df['malware_family'].replace("", "Unknown").nunique()

    problems = []
    if num_hashes < min_samples:
        problems.append(f"- Only {num_hashes} unique hashes found (need ≥ {min_samples}).")
    if num_fams < min_families:
        problems.append(f"- Only {num_fams} malware families found (need ≥ {min_families}).")

    if not problems:
        return

    if hard_require:
        msg = (
            "\nINPUT DATA DOES NOT MEET MINIMUMS:\n"
            + "\n".join(problems)
        )
        raise SystemExit(msg)
    logging.warning("Proceeding despite dataset not meeting minimums:\n" + "\n".join(problems))

print("✅ CSV helpers ready")


✅ CSV helpers ready


In [8]:
def select_sample_hashes(df: pd.DataFrame,
                         min_samples: int = TARGET_MIN_SAMPLES,
                         max_samples: int = TARGET_MAX_SAMPLES,
                         min_families: int = MIN_FAMILIES) -> pd.DataFrame:
    """
    Build a family-balanced selection by picking families round-robin until the target is hit.

    Args:
        df: Input rows; needs 'hash' and 'malware_family' columns. Empty family
            labels are treated as "Unknown".
        min_samples: Minimum number of unique hashes the *input* must contain
            (checked via assert_dataset_meets_requirements).
        max_samples: Upper bound on the number of rows returned.
        min_families: Minimum number of families required in the input and,
            best effort, in the selection.

    Returns:
        DataFrame of at most min(len(df), max_samples) rows with a fresh
        RangeIndex; empty (but keeping df's columns) when nothing can be selected.

    Raises:
        SystemExit: propagated from assert_dataset_meets_requirements when the
            input misses the minimums and hard requirements are enabled.
    """
    df = df.copy()
    df['malware_family'] = df['malware_family'].replace("", "Unknown")

    # Enforce minimums on the whole dataset
    assert_dataset_meets_requirements(df, min_samples=min_samples, min_families=min_families)

    # Families ordered by frequency (more populous first)
    families_ordered = df['malware_family'].value_counts().index.tolist()
    logging.info(f"Available families in input: {len(families_ordered)}")

    # Never exceed max_samples or the dataset size. The former nested min/max
    # expression evaluated to exactly this for every input; min_samples is
    # enforced on the input above, not on the selection size.
    target = min(len(df), max_samples)
    logging.info(f"Target sample count: {target}")

    if target <= 0:
        # pd.DataFrame([]) has no columns, so the family check below would raise
        # KeyError; return an empty frame that keeps the input schema instead.
        logging.warning("Nothing to select (empty input or max_samples <= 0)")
        return df.iloc[0:0].reset_index(drop=True)

    # Family pools shuffled deterministically (random_state=42) so repeated runs
    # select the same hashes.
    pools: Dict[str, List[Dict[str, Any]]] = {}
    for fam in families_ordered:
        pools[fam] = df[df['malware_family'] == fam].sample(frac=1.0, random_state=42).to_dict(orient='records')

    # Round-robin pick across families.
    # NOTE(review): after an exhausted family is removed, `fam_idx % len(families_rr)`
    # can jump ahead in the rotation, so later rounds are only approximately fair.
    # Left unchanged on purpose because altering it changes which hashes are picked;
    # TODO confirm whether the cached responses were fetched for exactly this
    # selection before "fixing" it.
    selected: List[Dict[str, Any]] = []
    families_rr = families_ordered[:]
    fam_idx = 0
    while len(selected) < target and families_rr:
        fam = families_rr[fam_idx % len(families_rr)]
        if pools[fam]:
            selected.append(pools[fam].pop())
            fam_idx += 1
        else:
            # remove empty families from rotation
            families_rr.pop(fam_idx % len(families_rr))

    sampled_df = pd.DataFrame(selected)

    # Ensure ≥ min_families in the final selection (swap in if needed)
    sel_fams = sampled_df['malware_family'].nunique()
    if sel_fams < min_families:
        logging.warning(f"Selection has only {sel_fams} families; trying to add more...")
        remaining_rows = df[~df['hash'].isin(sampled_df['hash'])]
        # One candidate row for each family that still has unselected rows.
        extras = (remaining_rows.groupby('malware_family').head(1).reset_index(drop=True))
        for _, row in extras.iterrows():
            if row['malware_family'] not in sampled_df['malware_family'].unique():
                if len(sampled_df) < target:
                    sampled_df = pd.concat([sampled_df, pd.DataFrame([row])], ignore_index=True)
                else:
                    fam_counts = sampled_df['malware_family'].value_counts()
                    if fam_counts.iloc[0] < 2:
                        # Every selected family holds a single row: any swap would
                        # drop one family to add another, so the count cannot grow.
                        break
                    # Replace one row from the most common family to keep the total unchanged
                    over_fam = fam_counts.index[0]
                    idx_drop = sampled_df[sampled_df['malware_family'] == over_fam].index[0]
                    sampled_df = sampled_df.drop(index=idx_drop)
                    sampled_df = pd.concat([sampled_df, pd.DataFrame([row])], ignore_index=True)
            if sampled_df['malware_family'].nunique() >= min_families:
                break

    logging.info(f"Selected {len(sampled_df)} samples from {sampled_df['malware_family'].nunique()} families")
    return sampled_df.reset_index(drop=True)

print("✅ Sampler ready")


✅ Sampler ready


In [9]:
def load_cached_response(hash_val: str, suffix: str, raw_dir: "str | None" = None) -> Dict[str, Any]:
    """
    Load ``<hash>_<suffix>.json`` from the cache directory and wrap it like the API helper did.

    Args:
        hash_val: Sample hash, used as the file-name prefix.
        suffix: Response kind, used as the file-name suffix (e.g. "file", "behavior").
        raw_dir: Directory holding the cached files; defaults to RAW_DIR.

    Returns:
        - Missing or unreadable file: ``{'success': False, 'status_code': None,
          'json': None, 'error': <message>, 'path': <path>}``.
        - Cached file already in wrapper form (a dict with 'success'): that dict,
          with 'path' added if absent.
        - Otherwise: ``{'success': True, 'status_code': None, 'json': <payload>, 'path': <path>}``.
    """
    directory = RAW_DIR if raw_dir is None else raw_dir
    path = os.path.join(directory, f"{hash_val}_{suffix}.json")
    if not os.path.exists(path):
        return {'success': False, 'status_code': None, 'json': None, 'error': f"missing cached {suffix} response", 'path': path}
    try:
        # JSON is UTF-8 by specification; don't depend on the platform locale.
        with open(path, 'r', encoding='utf-8') as fh:
            data = json.load(fh)
    except (OSError, ValueError) as exc:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        return {'success': False, 'status_code': None, 'json': None, 'error': f"failed to read cached {suffix} response: {exc}", 'path': path}

    if isinstance(data, dict) and 'success' in data:
        data.setdefault('path', path)
        return data

    return {'success': True, 'status_code': None, 'json': data, 'path': path}


def get_file_report(hash_val: str) -> Dict[str, Any]:
    """Return the cached VT file-report wrapper for `hash_val` (offline replacement for the API call)."""
    cached = load_cached_response(hash_val, suffix="file")
    return cached


def get_behavior_report(hash_val: str) -> Dict[str, Any]:
    """Return the cached VT behaviour-summary wrapper for `hash_val` (offline replacement for the API call)."""
    cached = load_cached_response(hash_val, suffix="behavior")
    return cached

print("✅ Cached response helpers ready")


✅ Cached response helpers ready


In [None]:
def extract_behavioral_indicators(behav_resp: Dict[str, Any]) -> Dict[str, Any]:
    """
    Parse a cached /behaviour_summary response into flattened indicator lists.

    Args:
        behav_resp: Wrapper dict with 'success' and 'json' keys; the summary
            fields are read from ``behav_resp['json']['data']``.

    Returns:
        Dict with the 13 keys below, each mapped to a list. Missing, falsy or
        wrongly-typed fields (anything that is not a list) yield [] so that
        later ``len()`` counts are always item counts. Structured fields are
        flattened:
          - dns_lookups -> hostnames
          - ip_traffic -> "ip:port"
          - http_conversations -> urls
          - mitre_attack_techniques -> technique dicts
    """
    indicators = {
        'processes_created': [],
        'files_written': [],
        'files_deleted': [],
        'registry_keys_set': [],
        'registry_keys_deleted': [],
        'dns_lookups': [],
        'ip_traffic': [],
        'http_conversations': [],
        'command_executions': [],
        'mutexes_created': [],
        'services_created': [],
        'tags': [],
        'mitre_techniques': []
    }

    # If the cached call is missing or invalid, return empty indicators
    if not behav_resp or not behav_resp.get('success'):
        return indicators

    def as_list(value: Any) -> List[Any]:
        # Only genuine lists are trusted: a stray string or dict would otherwise
        # be counted character-by-character / key-by-key downstream.
        return value if isinstance(value, list) else []

    try:
        j = behav_resp.get('json') or {}
        if not isinstance(j, dict):
            return indicators

        # Expected nesting: json['data']
        attrs = j.get('data') or {}
        if not isinstance(attrs, dict):
            return indicators

        # Fields copied through as-is (when they are lists)
        for key in ('processes_created', 'files_written', 'files_deleted',
                    'registry_keys_set', 'registry_keys_deleted',
                    'command_executions', 'mutexes_created', 'services_created'):
            indicators[key] = as_list(attrs.get(key))

        indicators['tags'] = [str(tag).strip() for tag in as_list(attrs.get('tags')) if str(tag).strip()]

        # Normalize structured items
        indicators['dns_lookups'] = [
            d.get('hostname', '') for d in as_list(attrs.get('dns_lookups')) if isinstance(d, dict)
        ]
        indicators['ip_traffic'] = [
            f"{ip.get('destination_ip','')}:{ip.get('destination_port','')}"
            for ip in as_list(attrs.get('ip_traffic')) if isinstance(ip, dict)
        ]
        indicators['http_conversations'] = [
            h.get('url','') for h in as_list(attrs.get('http_conversations')) if isinstance(h, dict)
        ]
        indicators['mitre_techniques'] = [
            t for t in as_list(attrs.get('mitre_attack_techniques')) if isinstance(t, dict)
        ]

    except Exception as e:
        # Best effort: keep whatever was parsed before the failure.
        logging.error(f"Error parsing behavior response safely: {e}")

    return indicators

print("✅ Behaviour parser ready")


✅ Behaviour parser ready


In [11]:
def build_attck_layer(tech_counter: Counter) -> Dict[str, Any]:
    """
    Assemble a minimal ATT&CK Navigator layer dict from technique counts.

    Every entry in `tech_counter` becomes one technique record whose ``score``
    is its count (as int). ``tactic`` is left as an empty string because no
    tactic lookup is performed here.
    """
    layer: Dict[str, Any] = {
        "name": "VT → MITRE ATT&CK Layer",
        "description": "Techniques extracted from cached VirusTotal behaviour summaries",
        "domain": "enterprise-attack",
        "version": "4.5",
    }
    technique_entries = []
    for technique_id, occurrences in tech_counter.items():
        technique_entries.append({"techniqueID": technique_id, "tactic": "", "score": int(occurrences)})
    layer["techniques"] = technique_entries
    return layer

print("✅ ATT&CK layer builder ready")


✅ ATT&CK layer builder ready


In [12]:

def _file_report_fields(h: str, fr: Dict[str, Any]) -> Dict[str, Any]:
    """Extract detection_ratio / first_seen / last_seen from a cached file report (None when absent)."""
    fields: Dict[str, Any] = {'detection_ratio': None, 'first_seen': None, 'last_seen': None}
    try:
        payload = fr.get('json')
        attrs = payload.get('data', {}).get('attributes', {}) if isinstance(payload, dict) else {}
        stats = attrs.get('last_analysis_stats', {}) if isinstance(attrs, dict) else {}
        total = sum(stats.values()) if isinstance(stats, dict) else 0
        malicious = stats.get('malicious', 0) if isinstance(stats, dict) else 0
        fields['detection_ratio'] = f"{malicious}/{total}" if total else None
        fields['first_seen'] = attrs.get('first_submission_date') or None
        fields['last_seen'] = attrs.get('last_analysis_date') or None
    except Exception as e:
        # Best effort: a malformed report leaves the remaining fields as None.
        logging.debug(f"Error extracting stats for {h}: {e}")
    return fields


def _save_partial_results(results: List[Dict[str, Any]], idx: int, total: int) -> None:
    """Write the rows collected so far to analysis_results_partial.{csv,json} in OUTPUT_DIR."""
    pd.DataFrame(results).to_csv(os.path.join(OUTPUT_DIR, "analysis_results_partial.csv"), index=False)
    with open(os.path.join(OUTPUT_DIR, "analysis_results_partial.json"), "w", encoding="utf-8") as fh:
        json.dump(results, fh, indent=2)
    logging.info(f"Saved partial results at {idx}/{total}")


def _collect_one(row: Dict[str, Any], tech_counter: Counter) -> Dict[str, Any]:
    """Build the result row for one sample and add its technique IDs to `tech_counter`."""
    h = row['hash']

    # Defaults describe a failed row; overwritten as data becomes available.
    result = {
        'hash': h,
        'family': row.get('malware_family', 'Unknown'),
        'source': row.get('source', 'Unknown'),
        'status': 'failed',
        'detection_ratio': None,
        'first_seen': None,
        'last_seen': None,
        'processes_count': 0,
        'files_written_count': 0,
        'files_deleted_count': 0,
        'registry_keys_set_count': 0,
        'dns_lookups_count': 0,
        'ip_connections_count': 0,
        'http_requests_count': 0,
        'mutexes_count': 0,
        'mitre_techniques_count': 0,
        'mitre_techniques': '',
        'tags_count': 0,
        'tags': '',
        'collected_date': datetime.now(timezone.utc).isoformat(),
    }

    # ---- 1) File report (cached): required, otherwise the row stays 'failed'
    fr = get_file_report(h)
    if not fr.get('success'):
        result['error'] = fr.get('error', 'file_report_missing')
        logging.warning(f"Skipping {h}: {result['error']}")
        return result
    result.update(_file_report_fields(h, fr))

    # ---- 2) Behaviour summary (cached): optional, missing -> empty indicators
    br = get_behavior_report(h)
    if not br.get('success'):
        logging.warning(f"No behaviour summary for {h}: {br.get('error', 'behavior_report_missing')}")
    indicators = extract_behavioral_indicators(br)

    # A behaviour summary can list the same technique ID several times; keep each
    # ID once per sample so the joined string, its count and the layer score agree
    # (apply_mitre_mapping_to_offline_results also counts unique IDs).
    technique_ids = list(dict.fromkeys(
        str(t['id']) for t in indicators['mitre_techniques']
        if isinstance(t, dict) and t.get('id')
    ))

    result.update({
        'processes_count': len(indicators['processes_created']),
        'files_written_count': len(indicators['files_written']),
        'files_deleted_count': len(indicators['files_deleted']),
        'registry_keys_set_count': len(indicators['registry_keys_set']),
        'dns_lookups_count': len(indicators['dns_lookups']),
        'ip_connections_count': len(indicators['ip_traffic']),
        'http_requests_count': len(indicators['http_conversations']),
        'mutexes_count': len(indicators['mutexes_created']),
        'mitre_techniques_count': len(technique_ids),
        'mitre_techniques': ", ".join(technique_ids),
        'tags_count': len(indicators['tags']),
        'tags': ", ".join(indicators['tags']),
    })

    # Layer score = number of samples exhibiting the technique.
    tech_counter.update(technique_ids)

    result['status'] = 'success'
    return result


def collect_samples(sample_df: pd.DataFrame,
                    save_every: int = 10) -> List[Dict[str, Any]]:
    """
    Aggregate cached VirusTotal data for every selected hash.

    For each hash:
      1) Load cached file report JSON from RAW_DIR
      2) Load cached behaviour summary JSON from RAW_DIR
      3) Parse indicators -> result row
    Then persist aggregated outputs (CSV/JSON + ATT&CK layer) to OUTPUT_DIR.

    Args:
        sample_df: Rows with a 'hash' column and optional 'malware_family' / 'source'.
        save_every: Checkpoint interval (rows) for the *_partial outputs; the
            last row always triggers a checkpoint. Values <= 0 only checkpoint
            at the end.

    Returns:
        One result dict per input row, in input order (failed rows included).
    """
    results: List[Dict[str, Any]] = []
    tech_counter: Counter = Counter()
    total_rows = len(sample_df)

    records = sample_df.to_dict(orient="records")
    for idx, row in enumerate(tqdm(records, desc="Parsing cached samples"), start=1):
        results.append(_collect_one(row, tech_counter))

        # Checkpoint for every row position, failed rows included, so a failure
        # on a checkpoint row no longer skips that save.
        if (save_every > 0 and idx % save_every == 0) or idx == total_rows:
            _save_partial_results(results, idx, total_rows)

    # Save final, full outputs
    df_final = pd.DataFrame(results)
    df_final.to_csv(os.path.join(OUTPUT_DIR, "analysis_results_offline.csv"), index=False)
    df_final.to_json(os.path.join(OUTPUT_DIR, "analysis_results_offline.json"), orient="records", indent=2)

    # Build ATT&CK Navigator layer
    layer_json = build_attck_layer(tech_counter)
    with open(os.path.join(OUTPUT_DIR, "attack_navigator_layer.json"), "w", encoding="utf-8") as fh:
        json.dump(layer_json, fh, indent=2)

    logging.info("Offline collection complete. Outputs saved to ./output")
    return results

print("✅ Collector ready (offline)")


✅ Collector ready (offline)


In [13]:
def main(input_csv: str = INPUT_CSV,
         min_samples: int = TARGET_MIN_SAMPLES,
         max_samples: int = TARGET_MAX_SAMPLES,
         min_families: int = MIN_FAMILIES):
    """
    Offline entry point that operates entirely on cached JSON files.

    Reads the hash CSV, selects a family-balanced sample, and aggregates the
    cached VirusTotal responses (collect_samples writes the outputs).

    Returns:
        DataFrame with one row per processed sample (empty when nothing was selected).

    Raises:
        SystemExit: when the CSV holds no hashes, or the input misses the
            minimums while HARD_REQUIRE_MINIMUMS is enabled.
    """
    logging.info("Reading input CSV...")
    df = read_hash_csv(input_csv)
    if df.empty:
        raise SystemExit("No hashes found in input CSV. Exiting.")

    # select_sample_hashes runs assert_dataset_meets_requirements itself;
    # calling it here as well only duplicated the warnings in soft mode.
    samples = select_sample_hashes(df, min_samples=min_samples, max_samples=max_samples, min_families=min_families)
    logging.info(f"Beginning offline aggregation for {len(samples)} samples using cached responses in {RAW_DIR}")

    results = collect_samples(samples)

    df_res = pd.DataFrame(results)
    # An empty result list produces a frame without a 'status' column.
    success_count = int((df_res['status'] == 'success').sum()) if 'status' in df_res.columns else 0
    logging.info(f"Finished. Success: {success_count}/{len(df_res)}")
    logging.info("Open output/attack_navigator_layer.json in ATT&CK Navigator to visualize technique coverage.")
    return df_res

print("✅ Main ready (offline)")


✅ Main ready (offline)


In [8]:
import json
import logging
import re
from pathlib import Path
from collections import defaultdict
from typing import Dict, Any, List, Tuple

def generate_mitre_mapping(results_path: str = "./output/analysis_results_offline.json",
                            attack_bundle_path: str = "./enterprise-attack.json",
                            mapping_output_path: str = "./output/mitre_technique_mapping.json") -> Dict[str, Any]:
    """
    Map retired MITRE technique IDs found in the offline results to their replacements.

    Reads the per-sample results JSON (records whose 'mitre_techniques' is a
    comma/pipe-separated ID string) and an ATT&CK STIX bundle. It then follows
    *outgoing* relationships whose type contains 'replace', 'revoked' or
    'duplicate'. For example, a STIX 'revoked-by' edge points from the retired
    object to its successor.

    The result has the shape
    ``{old_id: {"replacement": final_id, "relationships": [...]}}``.
    It is written to `mapping_output_path` and also returned.

    Raises:
        FileNotFoundError: if the results file or the ATT&CK bundle is missing.
    """
    results_file = Path(results_path)
    attack_file = Path(attack_bundle_path)
    if not results_file.exists():
        raise FileNotFoundError(f"Analysis file not found: {results_file}")
    if not attack_file.exists():
        raise FileNotFoundError(f"ATT&CK bundle not found: {attack_file}")

    def load_json_payload(path: Path) -> Any:
        with path.open('r', encoding='utf-8') as handle:
            return json.load(handle)

    def split_technique_list(raw: str) -> List[str]:
        if not raw:
            return []
        tokens = re.split(r'[|,]+', raw)
        return [token.strip() for token in tokens if token and token.strip()]

    def build_attack_maps(attack_stix: Dict[str, Any]) -> Tuple[Dict[str, str], Dict[str, str], Dict[str, List[Tuple[str, str]]]]:
        """Index attack-patterns by external ID / STIX id and collect replacement edges source -> [(target, type)]."""
        extid_to_objid: Dict[str, str] = {}
        objid_to_extid: Dict[str, str] = {}
        rel_from_to: Dict[str, List[Tuple[str, str]]] = defaultdict(list)
        for obj in attack_stix.get('objects', []):
            if obj.get('type') != 'attack-pattern':
                continue
            external_id = None
            for ref in obj.get('external_references', []):
                if ref.get('source_name') == 'mitre-attack' and ref.get('external_id', '').startswith('T'):
                    external_id = ref['external_id']
                    break
            if not external_id:
                continue
            obj_id = obj.get('id')
            extid_to_objid[external_id] = obj_id
            objid_to_extid[obj_id] = external_id
        for obj in attack_stix.get('objects', []):
            if obj.get('type') != 'relationship':
                continue
            rtype = (obj.get('relationship_type') or '').lower()
            if not rtype or not any(keyword in rtype for keyword in ('replace', 'revoked', 'duplicate')):
                continue
            src = obj.get('source_ref')
            tgt = obj.get('target_ref')
            if src and tgt:
                rel_from_to[src].append((tgt, rtype))
        return extid_to_objid, objid_to_extid, rel_from_to

    def find_replacement_extids(extid: str,
                                extid_to_objid: Dict[str, str],
                                objid_to_extid: Dict[str, str],
                                rel_from_to: Dict[str, List[Tuple[str, str]]]) -> Tuple[str, List[Tuple[str, str]]]:
        """
        Follow outgoing replacement edges starting at `extid`.

        Returns ``(primary, hops)``:
          - `hops`: every outgoing (technique_id, relationship) seen along the
            chain, in discovery order.
          - `primary`: the technique reached by repeatedly taking the first
            outgoing edge, or "" when there is none.

        Incoming edges are deliberately ignored: an edge *into* `extid` means
        `extid` is the successor, not the technique being replaced.
        """
        hops: List[Tuple[str, str]] = []
        primary = ""
        seen_ext = {extid}
        visited = set()  # guards against relationship cycles
        current = extid_to_objid.get(extid)
        while current and current not in visited:
            visited.add(current)
            next_oid = None
            for tgt_oid, rtype in rel_from_to.get(current, []):
                tgt_ext = objid_to_extid.get(tgt_oid)
                if not tgt_ext or tgt_ext in seen_ext:
                    continue
                seen_ext.add(tgt_ext)
                hops.append((tgt_ext, rtype))
                if next_oid is None:
                    next_oid = tgt_oid
                    primary = tgt_ext
            current = next_oid
        return primary, hops

    results_payload = load_json_payload(results_file)
    attack_bundle = load_json_payload(attack_file)
    extid_to_objid, objid_to_extid, rel_from_to = build_attack_maps(attack_bundle)

    unique_techniques: set = set()
    for record in results_payload:
        if isinstance(record, dict):
            unique_techniques.update(split_technique_list(record.get('mitre_techniques', '')))

    mapping: Dict[str, Dict[str, Any]] = {}
    for tid in sorted(unique_techniques):
        primary, hops = find_replacement_extids(tid, extid_to_objid, objid_to_extid, rel_from_to)
        if not primary or primary == tid:
            continue
        mapping[tid] = {
            'replacement': primary,
            'relationships': [
                {'technique_id': rep_id, 'relationship': rel or ''}
                for rep_id, rel in hops
            ]
        }

    output_path = Path(mapping_output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with output_path.open('w', encoding='utf-8') as handle:
        json.dump(mapping, handle, indent=2)

    logging.info("Identified %d unique MITRE techniques (%d replacements)", len(unique_techniques), len(mapping))
    logging.info("Saved mapping → %s", output_path)
    return mapping

print("✅ MITRE mapping helper ready")


✅ MITRE mapping helper ready


In [13]:
def apply_mitre_mapping_to_offline_results(mapping_path: str = "./output/mitre_technique_mapping.json",
                                            json_path: str = "./output/analysis_results_offline.json",
                                            csv_path: str = "./output/analysis_results_offline.csv") -> pd.DataFrame:
    """
    Append replacement technique IDs to every offline result and rewrite the JSON + CSV in place.

    For each record:
      - the pre-mapping technique string is stored in 'mitre_technique_old';
      - 'mitre_techniques' becomes the unique original IDs, each followed by its
        mapped replacement (if any);
      - 'mitre_techniques_count' is recomputed from those unique IDs.

    Re-running is idempotent: when a record already has 'mitre_technique_old',
    that value is used as the base instead of the already-augmented string.

    Returns:
        The updated records as a DataFrame (also written to `csv_path`).

    Raises:
        FileNotFoundError: if the mapping, JSON or CSV file is missing.
    """
    mapping_file = Path(mapping_path)
    json_file = Path(json_path)
    csv_file = Path(csv_path)

    if not mapping_file.exists():
        raise FileNotFoundError(f"Mapping file not found: {mapping_file}")
    if not json_file.exists():
        raise FileNotFoundError(f"Offline JSON results not found: {json_file}")
    if not csv_file.exists():
        raise FileNotFoundError(f"Offline CSV results not found: {csv_file}")

    with mapping_file.open('r', encoding='utf-8') as handle:
        mapping = json.load(handle)

    with json_file.open('r', encoding='utf-8') as handle:
        records = json.load(handle)

    def split_techniques(raw: str) -> List[str]:
        # Same ',' / '|' delimiters as generate_mitre_mapping uses.
        if not raw:
            return []
        return [tid.strip() for tid in re.split(r'[|,]+', raw) if tid.strip()]

    def unique_in_order(values: List[str]) -> List[str]:
        return list(dict.fromkeys(v for v in values if v))

    updated_records: List[Dict[str, Any]] = []
    replacements_applied = 0

    for record in records:
        # Prefer the preserved original so a second run doesn't treat the
        # augmented list as the "old" one.
        if 'mitre_technique_old' in record:
            original = record.get('mitre_technique_old') or ''
        else:
            original = record.get('mitre_techniques', '') or ''
        techniques = unique_in_order(split_techniques(original))

        combined: List[str] = []
        for tid in techniques:
            combined.append(tid)
            remap = mapping.get(tid)
            new_tid = remap.get('replacement') if isinstance(remap, dict) else None
            if new_tid:
                combined.append(new_tid)
        final = unique_in_order(combined)
        # Count only IDs that were actually added (not ones already present).
        replacements_applied += len(final) - len(techniques)

        record['mitre_technique_old'] = original
        record['mitre_techniques'] = ', '.join(final)
        record['mitre_techniques_count'] = len(final)
        updated_records.append(record)

    with json_file.open('w', encoding='utf-8') as handle:
        json.dump(updated_records, handle, indent=2)

    df = pd.DataFrame(updated_records)
    df.to_csv(csv_file, index=False)

    logging.info("Updated %d records; appended %d replacement techniques", len(updated_records), replacements_applied)
    logging.info("JSON updated → %s", json_file)
    logging.info("CSV updated → %s", csv_file)
    return df

print("✅ MITRE mapping application helper ready")

✅ MITRE mapping application helper ready


In [14]:
# Run the offline flow when you're ready.
main()


2025-11-10 14:30:48,106 - INFO - Reading input CSV...
2025-11-10 14:30:48,122 - INFO - Available families in input: 23
2025-11-10 14:30:48,122 - INFO - Target sample count: 100
2025-11-10 14:30:48,139 - INFO - Selected 100 samples from 23 families
2025-11-10 14:30:48,140 - INFO - Beginning offline aggregation for 100 samples using cached responses in ./output/raw_responses
2025-11-10 14:30:48,173 - INFO - Saved partial results at 10/100
2025-11-10 14:30:48,194 - INFO - Saved partial results at 20/100
2025-11-10 14:30:48,209 - INFO - Saved partial results at 30/100
2025-11-10 14:30:48,226 - INFO - Saved partial results at 40/100
2025-11-10 14:30:48,244 - INFO - Saved partial results at 50/100
Parsing cached samples:  56%|██████████████████████████████████████████████████████████████████████████████████████████████████                                                                             | 56/100 [00:00<00:00, 549.25it/s]2025-11-10 14:30:48,263 - INFO - Saved partial results at 60/

In [16]:
import json
import re
import requests
import os  # <-- New import
from pathlib import Path
from collections import defaultdict
from stix2 import MemoryStore
from stix2 import Filter  # Need this for stix2 queries

# --- Configuration ---
# Layer inputs/outputs live under ./output (created if missing).
output_dir = Path("output")
output_dir.mkdir(exist_ok=True, parents=True)
INPUT_FILE = output_dir.joinpath("analysis_results_offline.json")
OUTPUT_FILE = output_dir.joinpath("attack_layer_offline.json")

# ATT&CK enterprise bundle: the local copy is tried first, the URL is the fallback.
ATTACK_STIX_URL = "https://raw.githubusercontent.com/mitre/cti/master/enterprise-attack/enterprise-attack.json"
LOCAL_STIX_FILE = "enterprise-attack.json"

LAYER_NAME = "Malware Behavioral Layer (Authoritative Mapping)"
MAX_SCORE_CAP = 100

def parse_technique_ids(raw_value):
    """
    Return technique IDs from a ','/'|'-delimited string or an existing list.

    IDs are whitespace-trimmed, empty entries are dropped, and duplicates are
    removed while keeping first-seen order. Falsy input yields [].
    """
    if not raw_value:
        return []

    pieces = raw_value if isinstance(raw_value, list) else re.split(r'[|,]+', raw_value)

    # dict preserves insertion order, giving an ordered de-duplication.
    unique_ids = {}
    for piece in pieces:
        candidate = (piece or '').strip()
        if candidate:
            unique_ids.setdefault(candidate, None)
    return list(unique_ids)

def fetch_mitre_stix_data(url, local_file):
    """
    Return a stix2 MemoryStore of MITRE ATT&CK objects, preferring a local bundle.

    Args:
        url: Download location of the STIX bundle (used only as a fallback).
        local_file: Path of a local copy of the bundle, tried first.

    Returns:
        MemoryStore built from the bundle's "objects" list, or None when the
        download fails, the response is not JSON, or the payload is not a
        JSON object.
    """
    # 1. Try to load from a local file first
    if os.path.exists(local_file):
        print(f"Loading MITRE ATT&CK data from local file: {local_file}...")
        try:
            with open(local_file, 'r', encoding='utf-8') as f:
                stix_data = json.load(f)
            # stix2.MemoryStore expects the 'objects' list
            return MemoryStore(stix_data=stix_data.get("objects", []))
        except Exception as e:
            print(f"Error loading local file '{local_file}': {e}")

    # 2. Fallback to network request if local file is not found or fails
    print("Local file not found or failed to load. Attempting to fetch data over network...")
    try:
        # Generous timeout: the enterprise bundle is large
        response = requests.get(url, timeout=60)
        response.raise_for_status()
        # Older requests versions raise a plain ValueError on invalid JSON
        # instead of a RequestException subclass.
        stix_data = response.json()
    except (requests.RequestException, ValueError) as e:
        print(f"Error fetching MITRE data from URL: {e}")
        print("Please ensure you have a stable network connection.")
        return None

    if not isinstance(stix_data, dict):
        print("Error fetching MITRE data from URL: response is not a STIX bundle object.")
        return None

    print("Successfully downloaded MITRE ATT&CK data.")
    return MemoryStore(stix_data=stix_data.get("objects", []))

def build_accurate_tactic_map(stix_store):
    """
    Build ``{technique_id: [tactic short names]}`` from ATT&CK attack-pattern objects.

    Tactics come from each technique's ``kill_chain_phases`` (``phase_name`` is
    the tactic short name, e.g. "defense-evasion"); relationship objects are
    not consulted. The technique ID is taken from the external reference whose
    ``source_name`` is "mitre-attack" (the same lookup generate_mitre_mapping
    uses) instead of assuming it is the first reference. Objects without such a
    reference, or without phases, are omitted. Tactic lists are sorted and
    de-duplicated. Returns {} when `stix_store` is falsy.
    """
    if not stix_store:
        return {}

    tactic_map = defaultdict(set)

    # 1. Get all Techniques and Sub-techniques (type='attack-pattern')
    techniques = stix_store.query(Filter("type", "=", "attack-pattern"))

    # 2. Map Techniques to Tactics
    for tech in techniques:
        tech_id = next(
            (ref.get('external_id') for ref in tech.get('external_references', [])
             if ref.get('source_name') == 'mitre-attack' and ref.get('external_id')),
            None,
        )
        if not tech_id:
            continue

        # Techniques can belong to multiple tactics
        for phase in tech.get('kill_chain_phases', []):
            tactic_map[tech_id].add(phase['phase_name'].lower().replace(' ', '-'))

    return {k: sorted(v) for k, v in tactic_map.items()}

# ------------------------------------

def generate_attack_layer(input_file, output_file):
    """
    Build an ATT&CK Navigator layer file from the offline analysis results.

    Each technique's score is the number of distinct malware families whose
    samples list it. Tactics come from the ATT&CK kill-chain map, and the
    alphabetically-first tactic fills the ``tactic`` field. Progress and
    errors are printed; nothing is returned.
    """
    # 0. Authoritative technique -> tactics lookup
    stix_store = fetch_mitre_stix_data(ATTACK_STIX_URL, LOCAL_STIX_FILE)
    if not stix_store:
        print("Cannot proceed without MITRE ATT&CK data. Exiting.")
        return
    tactic_lookup = build_accurate_tactic_map(stix_store)

    # 1. Analysis records
    try:
        with open(input_file, 'r', encoding='utf-8') as f:
            analysis_data = json.load(f)
    except Exception as e:
        print(f"Error loading input file: {e}")
        return

    # 2. technique ID -> set of families that reported it
    families_by_technique = defaultdict(set)
    for sample in analysis_data:
        family = sample.get("family", "unknown_family")
        for technique_id in parse_technique_ids(sample.get("mitre_techniques")):
            families_by_technique[technique_id].add(family)

    if not families_by_technique:
        print("No MITRE techniques were found in the offline analysis file.")
        return

    # First matching tactic (in this priority order) decides the cell colour.
    tactic_colors = (
        ("execution", "#ff6666"),        # red
        ("persistence", "#cc99ff"),      # light purple
        ("defense-evasion", "#ffcc66"),  # yellow/orange
    )
    fallback_color = "#6c0bf4"  # used when none of the above match (a violet, not gray)

    max_family_count = max(len(fams) for fams in families_by_technique.values())

    # 3. One layer entry per technique, sorted by ID
    layer_techniques = []
    for technique_id in sorted(families_by_technique):
        family_count = len(families_by_technique[technique_id])
        official_tactics = tactic_lookup.get(technique_id, ["unknown-tactic"])
        tactics_label = ", ".join(t.replace('-', ' ').title() for t in official_tactics)
        color = next((hex_code for tactic, hex_code in tactic_colors if tactic in official_tactics), fallback_color)

        layer_techniques.append({
            "techniqueID": technique_id,
            "tactic": official_tactics[0],
            "score": family_count,
            "color": color,
            "metadata": [
                {"name": "Family Count", "value": str(family_count)},
                {"name": "Official Tactics", "value": tactics_label}
            ],
            "comment": f"Used by {family_count} families. Official Tactic(s): {tactics_label}."
        })

    # 4. Layer document; the gradient ceiling stays at MAX_SCORE_CAP unless scores reach it
    gradient_max = max_family_count + 5 if max_family_count >= MAX_SCORE_CAP else MAX_SCORE_CAP
    layer_json = {
        "name": LAYER_NAME,
        "description": f"Aggregated Techniques from {len(analysis_data)} samples.",
        "domain": "enterprise-attack",
        "version": "4.5",
        "techniques": layer_techniques,
        "gradient": {
            "colors": ["#ffffff", "#ff0000"],
            "minValue": 0,
            "maxValue": gradient_max
        },
        "hideDisabled": False,
        "showTacticRowBackground": True
    }

    # 5. Persist
    try:
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(layer_json, f, indent=2)
        print(f"\nSuccessfully generated ATT&CK Layer file: '{output_file}' with authoritative mapping.")
    except Exception as e:
        print(f"Error saving file: {e}")

if __name__ == '__main__':
    generate_attack_layer(INPUT_FILE, OUTPUT_FILE)


Loading MITRE ATT&CK data from local file: enterprise-attack.json...

Successfully generated ATT&CK Layer file: 'output/attack_layer_offline.json' with authoritative mapping.
