In [None]:
from collections import defaultdict
import json
import os
import sqlite3
import time

from mitreattack.stix20 import MitreAttackData
import nest_asyncio
from ratelimit import limits, sleep_and_retry
import vt

In [None]:
# Patch asyncio so that an event loop can be re-entered.
# NOTE(review): presumably needed because the notebook kernel already runs a loop
# and the synchronous vt client calls drive one of their own — confirm.
nest_asyncio.apply()

In [None]:
# VirusTotal API key. Read from the VT_API_KEY environment variable so the real
# key never has to be saved inside the notebook; the original placeholder is kept
# as the fallback for readers who still paste their key here directly.
VT_API_KEY = os.environ.get("VT_API_KEY", "<VT_API_KEY>")

In [None]:
# Load mitre attack data
# The path is relative, so "enterprise-attack.json" must be present in the
# notebook's working directory.
# NOTE(review): presumably the Enterprise ATT&CK STIX bundle — confirm source/version.
mitre_attack_data = MitreAttackData("enterprise-attack.json")

In [None]:
def build_mitre_map():
    """Index the loaded ATT&CK objects for lookup by external ID or by name.

    Reads the module-level ``mitre_attack_data`` and only considers external
    references whose ``source_name`` is ``'mitre-attack'``.

    Returns:
        dict: eight lookup tables. ``tactics``, ``techniques``, ``group_ids``,
        ``software_ids``, ``campaign_ids`` and ``datasources`` map an external
        ID to ``{"url": ..., "name": ...}``; ``group_names`` and
        ``campaign_names`` map an object name to ``{"url": ...}``.
    """
    def mitre_refs(objects):
        # Yield (object, reference) for each reference tagged 'mitre-attack'.
        for obj in objects:
            for ref in obj.external_references:
                if ref.source_name == 'mitre-attack':
                    yield obj, ref

    def index_by_external_id(objects):
        table = {}
        for obj, ref in mitre_refs(objects):
            table[ref.external_id] = {"url": ref.url, "name": obj.name}
        return table

    def index_by_name(objects):
        table = {}
        for obj, ref in mitre_refs(objects):
            table[obj.name] = {"url": ref.url}
        return table

    # Fetch every collection first, then index them.
    tactic_objs = mitre_attack_data.get_tactics()
    technique_objs = mitre_attack_data.get_techniques()
    group_objs = mitre_attack_data.get_groups()
    software_objs = mitre_attack_data.get_software()
    campaign_objs = mitre_attack_data.get_campaigns()
    datasource_objs = mitre_attack_data.get_datasources()

    lookup = {}
    lookup["tactics"] = index_by_external_id(tactic_objs)
    lookup["techniques"] = index_by_external_id(technique_objs)
    lookup["group_names"] = index_by_name(group_objs)
    lookup["group_ids"] = index_by_external_id(group_objs)
    lookup["software_ids"] = index_by_external_id(software_objs)
    lookup["campaign_names"] = index_by_name(campaign_objs)
    lookup["campaign_ids"] = index_by_external_id(campaign_objs)
    lookup["datasources"] = index_by_external_id(datasource_objs)
    return lookup

# Built once when this cell runs; enrich_mitre() reads this module-level mapping.
mitre_map = build_mitre_map()

In [None]:
def enrich_mitre(mitre):
    """Resolve a report's MITRE identifiers against the module-level ``mitre_map``.

    Args:
        mitre: mapping of category name (a key of ``mitre_map``) to an
            iterable of identifiers.

    Returns:
        dict: for every category in ``mitre``, a dict holding only the
        identifiers found in ``mitre_map[category]`` together with their
        stored details. Unknown identifiers are dropped.

    Raises:
        KeyError: if a category is not present in ``mitre_map``.
    """
    enriched_by_category = {}
    for category, identifiers in mitre.items():
        known = mitre_map[category]
        matches = {}
        for identifier in identifiers:
            if identifier in known:
                matches[identifier] = known[identifier]
        enriched_by_category[category] = matches
    return enriched_by_category

In [None]:
@sleep_and_retry
@limits(calls=1, period=20) # VT restricts to 4 QPM
def _vt_has_file(vt_client, fh):
    """Return True when VirusTotal has a file object for hash ``fh``.

    The rate limit sits on this per-request helper so that every single
    VirusTotal call is throttled, not just every batch of hashes.
    Failures never propagate (the lookup is best effort): a hash unknown to
    VirusTotal is silently reported as absent, any other error is printed.
    """
    try:
        vt_client.get_object(f"/files/{fh}")
    except vt.APIError as err:
        # "NotFoundError" just means VT has no record of this hash; anything
        # else (e.g. quota or auth problems) is worth surfacing.
        if err.code != "NotFoundError":
            print(f"VirusTotal lookup failed for {fh}: {err}")
        return False
    except Exception as err:
        # Catch Exception rather than everything so Ctrl-C / kernel interrupt
        # still stops the notebook.
        print(f"VirusTotal lookup failed for {fh}: {err}")
        return False
    return True


def enrich_hashes(vt_client, filehashes):
    """Map each hash known to VirusTotal to its VirusTotal GUI link.

    Args:
        vt_client: client exposing ``get_object(path)``.
        filehashes: iterable of file hashes (as strings).

    Returns:
        defaultdict: ``{hash: {"VirusTotal": url}}`` containing only the
        hashes whose lookup succeeded.
    """
    res = defaultdict(dict)
    for fh in filehashes:
        if _vt_has_file(vt_client, fh):
            res[fh]["VirusTotal"] = f"https://www.virustotal.com/gui/file/{fh}"

    return res

In [None]:
def split(iocs):
    """Break a comma-separated string into a list; falsy input yields ``[]``."""
    if not iocs:
        return []
    return iocs.split(",")

def enrich_reports(sq_conn):
    """Enrich every row of ``report`` and upsert the result into ``enriched_report``.

    For each report the MITRE identifiers are resolved through
    ``enrich_mitre`` and every non-empty hash list is passed to
    ``enrich_hashes``. The results are stored as JSON text, committing once
    per report.

    Args:
        sq_conn: open ``sqlite3`` connection; the script in
            ``sqlite_schema.sql`` is applied to it first.

    Raises:
        RuntimeError: if an upsert does not affect exactly one row.
    """
    # Initialize database
    with open("sqlite_schema.sql") as f:
        try:
            with sq_conn:
                sq_conn.executescript(f.read())
        except Exception as e:
            # Deliberately non-fatal: the failure is reported and processing continues.
            # NOTE(review): presumably so a re-run against an initialized DB works — confirm.
            print(f"error applying schema: {e}")

    read_cur = sq_conn.cursor()
    update_cur = sq_conn.cursor()

    vt_client = vt.Client(apikey=VT_API_KEY)

    count = 0
    try:
        for row in read_cur.execute("SELECT id, md5s, sha1s, sha256s, mitre FROM report"):
            report_id, md5s, sha1s, sha256s, mitre = row

            # Treat a NULL/empty mitre column (or a JSON null) as "no identifiers",
            # mirroring how split() already tolerates NULL hash columns.
            mitre = (json.loads(mitre) if mitre else None) or {}

            data = {
                "report_id": report_id,
                "mitre": json.dumps(enrich_mitre(mitre)),
            }

            for column, raw in (("md5s", md5s), ("sha1s", sha1s), ("sha256s", sha256s)):
                hashes = split(raw)
                # Skip the rate-limited lookup entirely when there is nothing to look up.
                enriched = enrich_hashes(vt_client, hashes) if hashes else {}
                data[column] = json.dumps(enriched)

            with sq_conn:
                update_cur.execute("""
                    INSERT OR REPLACE INTO enriched_report (
                        report_id,
                        mitre,
                        md5s,
                        sha1s,
                        sha256s
                    ) VALUES(
                        :report_id,
                        :mitre,
                        :md5s,
                        :sha1s,
                        :sha256s
                    )
                """, data)
                if update_cur.rowcount != 1:
                    raise RuntimeError(f"error inserting enriched mitre for {report_id}. rowcount is {update_cur.rowcount}")

            count += 1
            print(f"Updated report {report_id} ({count})")
    finally:
        # Release the VirusTotal client even when a report fails part-way through.
        vt_client.close()

    print(f"Updated {count} rows")

In [None]:
if __name__ == '__main__':
    # Connect to sqlite database
    # (relative path: "reports.db" is resolved against the current working directory)
    sq_conn = sqlite3.connect("reports.db")
    try:
        enrich_reports(sq_conn)
    finally:
        # Always release the connection, even if enrichment raised.
        sq_conn.close()