In [0]:
import pandas as pd
import numpy as np
from sklearn.preprocessing import LabelEncoder
import torch
from torch_geometric.data import Data
import os
import torch.nn.functional as F
import json
import warnings
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
warnings.filterwarnings('ignore')
from torch_geometric.loader import NeighborLoader
import multiprocessing

# Target device for the model and graph tensors: CUDA when available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
%matplotlib inline

# Flash Evaluation on DARPA E3 Theia Dataset: 

This notebook evaluates Flash on the DARPA E3 Theia dataset. Theia is a node-level dataset, so Flash is configured to operate in a node-level setting here. The Theia dataset lacks essential node attributes for some node types, which means Flash cannot be run in decoupled mode with offline GNN embeddings on this dataset. Instead, we use an online GNN coupled with word2vec semantic embeddings to obtain the evaluation results.

## Dataset Access: 
- Access the Theia dataset via the following link: [Theia Dataset](https://drive.google.com/drive/folders/1QlbUFWAGq3Hpl8wVdzOdIoZLFxkII4EK).
- The dataset files will be downloaded automatically by the script.

## Data Parsing and Execution:
- The script is designed to automatically parse the downloaded data files.
- Execute all cells within this notebook to obtain the evaluation results.

## Model Training and Execution Flexibility:
- The notebook is configured to use pre-trained model weights by default.
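- Set `Train = True` to train the word2vec model and the Graph Neural Networks (GNNs) yourself instead of relying on the pre-trained weights.
- Newly trained models are saved to the working directory, while the cells that load models read from `trained_weights/theia/`; copy the new files there to use them for a comprehensive evaluation of the dataset.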

Adhere to these steps for a detailed and effective analysis of the Theia dataset using Flash.


In [0]:
import gdown
# Google Drive share links for the two files this notebook downloads
# (presumably the two .tar.gz archives unpacked by run_data_processing() - verify).
urls = ["https://drive.google.com/file/d/10cecNtR3VsHfV0N-gNEeoVeB89kCnse5/view?usp=drive_link",
        "https://drive.google.com/file/d/1Kadc6CUTb4opVSDE4x6RFFnEy0P1cRp0/view?usp=drive_link"]
for url in urls:
    # Each file is saved into the current working directory under the name gdown picks.
    gdown.download(url, quiet=False, use_cookies=False, fuzzy=True)

In [0]:
# Master switch for the `if Train:` cells below: when True they build the training
# graph and train word2vec and the GNN; when False those cells are skipped.
Train = False

In [None]:
from pprint import pprint
import gzip
from sklearn.manifold import TSNE
import json
import copy
import os

In [None]:
import re

def extract_uuid(line):
    """Return every value that follows a ``uuid":"`` key in ``line``.

    The values are returned as a list of strings in order of appearance;
    the list is empty when the line contains no such key.
    """
    return re.findall(r'uuid\":\"(.*?)\"', line)

def extract_subject_type(line):
    """Return every value that follows a ``type":"`` key in ``line``.

    The values are returned as a list of strings in order of appearance;
    the list is empty when the line contains no such key.
    """
    return re.findall(r'type\":\"(.*?)\"', line)

def show(file_path):
    """Print a one-line progress message naming the file being processed."""
    print("Processing {}".format(file_path))

def extract_edge_info(line):
    """Pull the edge fields out of one raw CDM18 Event record.

    Args:
        line: One line of the JSON log, as text.

    Returns:
        A 5-tuple ``(src_id, edge_type, timestamp, dst_id1, dst_id2)`` of
        strings. ``edge_type`` is the first ``type`` value on the line and
        ``timestamp`` the text after the first ``timestampNanos":`` up to the
        next comma. ``dst_id1`` / ``dst_id2`` are ``None`` when the line has no
        ``predicateObject`` / ``predicateObject2`` UUID. All five entries are
        ``None`` when the line has no subject UUID, no ``type`` value or no
        ``timestampNanos`` value, so callers can skip the record instead of
        hitting an IndexError.
    """
    pattern_src = re.compile(r'subject\":{\"com.bbn.tc.schema.avro.cdm18.UUID\":\"(.*?)\"}')
    pattern_dst1 = re.compile(r'predicateObject\":{\"com.bbn.tc.schema.avro.cdm18.UUID\":\"(.*?)\"}')
    pattern_dst2 = re.compile(r'predicateObject2\":{\"com.bbn.tc.schema.avro.cdm18.UUID\":\"(.*?)\"}')
    pattern_type = re.compile(r'type\":\"(.*?)\"')
    pattern_time = re.compile(r'timestampNanos\":(.*?),')

    src_ids = pattern_src.findall(line)
    edge_types = pattern_type.findall(line)
    timestamps = pattern_time.findall(line)

    # A record missing any mandatory field cannot form an edge.
    if not (src_ids and edge_types and timestamps):
        return None, None, None, None, None

    def first_uuid(matches):
        # First captured UUID, or None when nothing matched or the capture is the text 'null'.
        return matches[0] if matches and matches[0] != 'null' else None

    dst_id1 = first_uuid(pattern_dst1.findall(line))
    dst_id2 = first_uuid(pattern_dst2.findall(line))

    return src_ids[0], edge_types[0], timestamps[0], dst_id1, dst_id2

def process_data(file_path):
    """Build a ``{node uuid: node type}`` map from a (possibly split) JSON log.

    The log is read from ``file_path`` and then from ``file_path.1``,
    ``file_path.2``, ... (at most up to ``.99``), stopping at the first part
    that does not exist.

    Args:
        file_path: Path of the first log part.

    Returns:
        Dict mapping the first ``uuid`` on each node record to the first
        ``type`` value on that record. Records without a ``type`` value are
        labelled ``'MemoryObject'``, ``'NetFlowObject'`` or
        ``'UnnamedPipeObject'`` when the line names that record kind.
        Lines without a uuid, and untyped records of any other kind, are
        skipped rather than raising an IndexError.
    """
    # Record kinds that are not graph nodes and are ignored outright.
    non_node_records = (
        'com.bbn.tc.schema.avro.cdm18.Event',
        'com.bbn.tc.schema.avro.cdm18.Host',
        'com.bbn.tc.schema.avro.cdm18.TimeMarker',
        'com.bbn.tc.schema.avro.cdm18.StartMarker',
        'com.bbn.tc.schema.avro.cdm18.UnitDependency',
        'com.bbn.tc.schema.avro.cdm18.EndMarker',
    )
    # Node kinds labelled by record name when the line has no "type" value,
    # checked in this order.
    untyped_node_kinds = ('MemoryObject', 'NetFlowObject', 'UnnamedPipeObject')

    id_nodetype_map = {}
    notice_num = 1000000
    for i in range(100):
        # Part 0 is the bare path; later parts carry a numeric suffix.
        now_path = file_path if i == 0 else file_path + '.' + str(i)
        if not os.path.exists(now_path):
            break

        with open(now_path, 'r') as f:
            show(now_path)
            for cnt, line in enumerate(f, start=1):
                if cnt % notice_num == 0:
                    print(cnt)

                if any(record in line for record in non_node_records):
                    continue

                uuids = extract_uuid(line)
                if not uuids:
                    # Blank or malformed line: nothing to register.
                    continue
                uuid = uuids[0]

                subject_type = extract_subject_type(line)
                if subject_type:
                    id_nodetype_map[uuid] = subject_type[0]
                    continue

                for kind in untyped_node_kinds:
                    if 'com.bbn.tc.schema.avro.cdm18.' + kind in line:
                        id_nodetype_map[uuid] = kind
                        break

    return id_nodetype_map

def process_edges(file_path, id_nodetype_map):
    """Write a tab-separated edge list next to every part of a JSON log.

    For each existing part ``file_path``, ``file_path.1``, ... (stopping at the
    first missing part, at most up to ``.99``) a file ``<part>.txt`` is written.
    Every Event line whose subject is a known node yields one row per known
    destination (``predicateObject`` and ``predicateObject2``):
    ``src_id, src_type, dst_id, dst_type, edge_type, timestamp``.

    Args:
        file_path: Path of the first log part.
        id_nodetype_map: ``{uuid: node type}`` map, e.g. from ``process_data``.

    Returns:
        Number of Event lines skipped because their subject was missing or not
        present in ``id_nodetype_map`` (previously counted but discarded).
    """
    notice_num = 1000000
    not_in_cnt = 0

    for i in range(100):
        # Part 0 is the bare path; later parts carry a numeric suffix.
        now_path = file_path if i == 0 else file_path + '.' + str(i)
        if not os.path.exists(now_path):
            break

        with open(now_path, 'r') as f, open(now_path+'.txt', 'w') as fw:
            for cnt, line in enumerate(f, start=1):
                if cnt % notice_num == 0:
                    print(cnt)

                if 'com.bbn.tc.schema.avro.cdm18.Event' not in line:
                    continue

                src_id, edge_type, timestamp, dst_id1, dst_id2 = extract_edge_info(line)

                if src_id is None or src_id not in id_nodetype_map:
                    not_in_cnt += 1
                    continue

                src_type = id_nodetype_map[src_id]

                # One output row per destination that is a known node.
                for dst_id in (dst_id1, dst_id2):
                    if dst_id is not None and dst_id in id_nodetype_map:
                        dst_type = id_nodetype_map[dst_id]
                        fw.write(f"{src_id}\t{src_type}\t{dst_id}\t{dst_type}\t{edge_type}\t{timestamp}\n")

    return not_in_cnt

def run_data_processing():
    """Unpack both Theia archives, derive their edge lists and stage train/test files.

    Steps, all relative to the current working directory:
    1. extract the two ``.tar.gz`` archives with ``tar``;
    2. for each extracted log, build the uuid -> node type map and write the
       per-part ``.txt`` edge lists;
    3. copy the ``1r`` edge list to ``theia_train.txt`` and part 8 of the
       ``6r`` edge list to ``theia_test.txt``.

    Shell commands are run through ``os.system``; their exit codes are not checked.
    """
    extract_commands = (
        'tar -zxvf ta1-theia-e3-official-1r.json.tar.gz',
        'tar -zxvf ta1-theia-e3-official-6r.json.tar.gz',
    )
    for command in extract_commands:
        os.system(command)

    for log_path in ['ta1-theia-e3-official-1r.json', 'ta1-theia-e3-official-6r.json']:
        node_types = process_data(log_path)
        process_edges(log_path, node_types)

    copy_commands = (
        'cp ta1-theia-e3-official-1r.json.txt theia_train.txt',
        'cp ta1-theia-e3-official-6r.json.8.txt theia_test.txt',
    )
    for command in copy_commands:
        os.system(command)

In [None]:
run_data_processing()

In [None]:
def add_node_properties(nodes, node_id, properties):
    """Append ``properties`` to the list stored for ``node_id`` in ``nodes``.

    A new empty list is created first when ``node_id`` has no entry yet.
    ``nodes`` is modified in place; nothing is returned.
    """
    nodes.setdefault(node_id, []).extend(properties)

def update_edge_index(edges, edge_index, index):
    """Translate ``(src_id, dst_id)`` pairs into positions and append them.

    Args:
        edges: Iterable of ``(src_id, dst_id)`` pairs.
        edge_index: Two lists; source positions are appended to
            ``edge_index[0]`` and destination positions to ``edge_index[1]``.
        index: Mapping from node id to its integer position.
    """
    for src_id, dst_id in edges:
        # Resolve both endpoints before appending so a missing id leaves the rows aligned.
        src_pos, dst_pos = index[src_id], index[dst_id]
        edge_index[0].append(src_pos)
        edge_index[1].append(dst_pos)

def prepare_graph(df):
    """Turn an event table into node token lists, node labels and an edge index.

    Each row contributes the tokens ``[exec, action]`` (plus ``path`` when it is
    truthy) to both its actor node and its object node, and one edge from the
    actor to the object.

    Args:
        df: DataFrame with the columns ``actorID``, ``actor_type``,
            ``objectID``, ``object``, ``action``, ``exec`` and ``path``.

    Returns:
        ``(features, feat_labels, edge_index, node_ids)`` where, for nodes in
        first-seen order, ``features`` holds each node's token list,
        ``feat_labels`` its integer type label, ``node_ids`` its id, and
        ``edge_index`` is ``[sources, targets]`` expressed as positions into
        those lists.

    Raises:
        KeyError: if a node type is not one of the six known type names.
    """
    type_to_label = {"SUBJECT_PROCESS": 0, "MemoryObject": 1, "FILE_OBJECT_BLOCK": 2,
                     "NetFlowObject": 3, "PRINCIPAL_REMOTE": 4, 'PRINCIPAL_LOCAL': 5}
    node_tokens, node_label, edge_pairs = {}, {}, []

    for _, event in df.iterrows():
        tokens = [event['exec'], event["action"]]
        if event['path']:
            tokens.append(event['path'])

        src_id = event["actorID"]
        add_node_properties(node_tokens, src_id, tokens)
        node_label[src_id] = type_to_label[event['actor_type']]

        dst_id = event["objectID"]
        add_node_properties(node_tokens, dst_id, tokens)
        node_label[dst_id] = type_to_label[event['object']]

        edge_pairs.append((src_id, dst_id))

    # Dict insertion order gives the first-seen node order used for positions.
    node_ids = list(node_tokens)
    features = [node_tokens[node_id] for node_id in node_ids]
    feat_labels = [node_label[node_id] for node_id in node_ids]
    position_of = {node_id: position for position, node_id in enumerate(node_ids)}

    edge_index = [[], []]
    update_edge_index(edge_pairs, edge_index, position_of)

    return features, feat_labels, edge_index, node_ids

In [None]:
from torch_geometric.nn import GCNConv
from torch_geometric.nn import SAGEConv, GATConv
import torch.nn as nn

class GCN(torch.nn.Module):
    """Two-layer node classifier built from SAGEConv layers (despite the class name).

    Layer sizes are ``in_channel -> 32 -> out_channel``; both layers are created
    with ``normalize=True``.
    """

    def __init__(self,in_channel,out_channel):
        super().__init__()
        hidden_dim = 32
        # Attribute names are part of the saved state_dict keys; keep them as conv1/conv2.
        self.conv1 = SAGEConv(in_channel, hidden_dim, normalize=True)
        self.conv2 = SAGEConv(hidden_dim, out_channel, normalize=True)

    def forward(self, x, edge_index):
        """Return the second layer's output for node features ``x`` over ``edge_index``."""
        hidden = F.relu(self.conv1(x, edge_index))
        # Dropout is only active in training mode.
        hidden = F.dropout(hidden, p=0.5, training=self.training)
        return self.conv2(hidden, edge_index)

In [None]:
def visualize(h, color):
    """Scatter-plot a 2-D t-SNE projection of the tensor ``h``.

    Args:
        h: Tensor of embeddings; it is detached and moved to the CPU first.
        color: Per-point colour values passed to ``plt.scatter`` (``c=``).
    """
    embeddings = h.detach().cpu().numpy()
    projected = TSNE(n_components=2).fit_transform(embeddings)
    xs, ys = projected[:, 0], projected[:, 1]

    plt.figure(figsize=(10,10))
    # Axis ticks are hidden.
    plt.xticks([])
    plt.yticks([])

    plt.scatter(xs, ys, s=70, c=color, cmap="Set2")
    plt.show()

In [None]:
from gensim.models.callbacks import CallbackAny2Vec
import gensim
from gensim.models import Word2Vec
from multiprocessing import Pool
from itertools import compress
from tqdm import tqdm
import time

class EpochSaver(CallbackAny2Vec):
    """Training callback that saves the model at the end of every epoch.

    Every epoch overwrites the same file, 'word2vec_theia_E3.model', given as a
    relative path.
    """

    def __init__(self):
        # Number of epochs completed so far.
        self.epoch = 0

    def on_epoch_end(self, model):
        """Save ``model`` to 'word2vec_theia_E3.model' and advance the epoch counter."""
        model.save('word2vec_theia_E3.model')
        self.epoch += 1

In [None]:
class EpochLogger(CallbackAny2Vec):
    """Training callback that prints a line at the start and end of every epoch."""

    def __init__(self):
        # Zero-based number of the epoch currently being reported.
        self.epoch = 0

    def on_epoch_begin(self, model):
        """Announce the start of the current epoch."""
        print(f"Epoch #{self.epoch} start")

    def on_epoch_end(self, model):
        """Announce the end of the current epoch, then move on to the next number."""
        print(f"Epoch #{self.epoch} end")
        self.epoch = self.epoch + 1

In [None]:
# Callback instances passed to Word2Vec(..., callbacks=[saver, logger]) in the training cell.
logger = EpochLogger()
saver = EpochSaver()

In [None]:
def add_attributes(d,p):
    """Attach ``exec`` and ``path`` columns from the raw JSON log to an edge table.

    Every line of ``p`` containing the text ``EVENT`` is parsed as JSON. For each
    parsed record one row is built for ``predicateObject`` and, when the record
    has a ``predicateObject2`` UUID, another row for it. The rows are joined to
    ``d`` on actor, object, action and timestamp.

    Args:
        d: DataFrame with at least the columns ``actorID``, ``objectID``,
            ``action`` and ``timestamp``.
        p: Path of the JSON log (one record per line).

    Returns:
        The inner merge of ``d`` (all values cast to ``str``) with the
        extracted rows, with duplicate rows dropped. The result gains the
        columns ``exec`` (the event's ``cmdLine`` property) and ``path``.
        Fields absent from a record become the empty string.
    """
    event_key = 'com.bbn.tc.schema.avro.cdm18.Event'
    uuid_key = 'com.bbn.tc.schema.avro.cdm18.UUID'
    missing = object()  # distinguishes "lookup failed" from any real value

    def dig(record, *keys, default=''):
        # Follow record['datum'][event_key][*keys]; any absent key or null
        # intermediate value yields `default`.
        value = record
        try:
            for key in ('datum', event_key) + keys:
                value = value[key]
        except (KeyError, TypeError, IndexError):
            return default
        return value

    info = []
    # Stream the log line by line; the handle is closed when the block ends.
    with open(p) as f:
        for line in f:
            if "EVENT" not in line:
                continue
            x = json.loads(line)

            action = dig(x, 'type')
            actor = dig(x, 'subject', uuid_key)
            obj = dig(x, 'predicateObject', uuid_key)
            timestamp = dig(x, 'timestampNanos')
            cmd = dig(x, 'properties', 'map', 'cmdLine')
            path = dig(x, 'predicateObjectPath', 'string')
            path2 = dig(x, 'predicateObject2Path', 'string')

            # The secondary object gets its own row, paired with the secondary path.
            obj2 = dig(x, 'predicateObject2', uuid_key, default=missing)
            if obj2 is not missing:
                info.append({'actorID':actor,'objectID':obj2,'action':action,'timestamp':timestamp,'exec':cmd, 'path':path2})

            info.append({'actorID':actor,'objectID':obj,'action':action,'timestamp':timestamp,'exec':cmd, 'path':path})

    rdf = pd.DataFrame.from_records(info).astype(str)
    d = d.astype(str)

    return d.merge(rdf,how='inner',on=['actorID','objectID','action','timestamp']).drop_duplicates()

In [None]:
if Train:
    # Read the tab-separated training edge list; the handle is closed on leaving the block.
    with open("theia_train.txt") as f:
        data = [line.split('\t') for line in f.read().split('\n')]
    df = pd.DataFrame(data, columns=['actorID', 'actor_type', 'objectID', 'object', 'action', 'timestamp'])
    # Rows with fewer than six fields (e.g. the empty string after the final
    # newline) are padded with missing values and dropped here.
    df = df.dropna()
    # timestamp is still text at this point, so this is a string sort.
    df = df.sort_values(by='timestamp', ascending=True)
    # Attach the exec/path attributes from the raw log, then build the graph inputs.
    df = add_attributes(df,"ta1-theia-e3-official-1r.json")
    phrases,labels,edges,mapp = prepare_graph(df)

In [None]:
from sklearn.utils import class_weight
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss

# 30 input features per node (the word2vec vector_size used in this notebook) and 5 output classes.
model = GCN(30,5).to(device)
# Adam with learning rate 0.01 and weight decay 5e-4; only used by the training cell.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

In [None]:
if Train:
    # Train 30-dimensional word2vec embeddings on the per-node token lists.
    # The `saver` callback writes the model to disk after every epoch; `logger` prints progress.
    word2vec = Word2Vec(sentences=phrases, vector_size=30, window=5, min_count=1, workers=8,epochs=300,callbacks=[saver,logger])

In [None]:
import math
import torch
import numpy as np
from gensim.models import Word2Vec

class PositionalEncoder:
    """Sinusoidal positional encoding table added to a sequence of embeddings.

    ``pe`` has shape ``(max_len, d_model)``: even columns hold
    ``sin(position * div_term)`` and odd columns ``cos(position * div_term)``,
    with ``div_term = exp(2k * -ln(10000) / d_model)`` for column pair ``k``.
    Both even and odd ``d_model`` are supported.
    """

    def __init__(self, d_model, max_len=100000):
        """Precompute the table.

        Args:
            d_model: Embedding dimension (number of columns).
            max_len: Number of positions (rows) to precompute.
        """
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term
        self.pe = torch.zeros(max_len, d_model)
        self.pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one cosine column fewer than sine columns,
        # so only the first d_model // 2 frequencies are used here.
        self.pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

    def embed(self, x):
        """Return ``x`` plus the encodings of positions ``0 .. x.size(0) - 1``.

        ``x`` is indexed by position along its first dimension and must not have
        more rows than ``max_len``; otherwise the addition fails on mismatched shapes.
        """
        return x + self.pe[:x.size(0)]

def infer(document):
    """Embed a list of tokens as one fixed-size vector.

    Tokens missing from the module-level ``w2vmodel`` vocabulary are ignored.
    The remaining word vectors get the positional encoding of the module-level
    ``encoder`` added (skipped when ``document`` has 100000 tokens or more) and
    are then averaged.

    Args:
        document: Sequence of tokens.

    Returns:
        1-D numpy array of length ``w2vmodel.wv.vector_size``. It is all zeros
        when no token is in the vocabulary; this used to be a hard-coded
        length of 20, which did not match the embedding size and produced a
        ragged node-feature matrix.
    """
    word_embeddings = [w2vmodel.wv[word] for word in document if word in w2vmodel.wv]

    if not word_embeddings:
        return np.zeros(w2vmodel.wv.vector_size)

    # Stack into one ndarray first: building a tensor from a list of arrays is slow.
    output_embedding = torch.tensor(np.array(word_embeddings), dtype=torch.float)
    if len(document) < 100000:
        output_embedding = encoder.embed(output_embedding)

    output_embedding = output_embedding.detach().cpu().numpy()
    return np.mean(output_embedding, axis=0)

# Positional encoder sized to the 30-dimensional word vectors used by infer().
encoder = PositionalEncoder(30)
# NOTE(review): this always loads the weights under trained_weights/theia/, while
# EpochSaver writes a freshly trained model to 'word2vec_theia_E3.model' in the working
# directory - confirm whether a model trained with Train = True is meant to be copied here.
w2vmodel = Word2Vec.load("trained_weights/theia/word2vec_theia_E3.model")

In [None]:
from torch_geometric import utils

if Train:
    # Weight the loss inversely to class frequency.
    l = np.array(labels)
    class_weights = class_weight.compute_class_weight(class_weight = "balanced",classes = np.unique(l),y = l)
    class_weights = torch.tensor(class_weights,dtype=torch.float).to(device)
    criterion = CrossEntropyLoss(weight=class_weights,reduction='mean')

    # One embedding per node, computed from its token list.
    nodes = np.array([infer(x) for x in phrases])

    graph = Data(x=torch.tensor(nodes,dtype=torch.float).to(device),y=torch.tensor(labels,dtype=torch.long).to(device), edge_index=torch.tensor(edges,dtype=torch.long).to(device))
    graph.n_id = torch.arange(graph.num_nodes)
    # mask[i] is True while node i is still used as a seed node for training.
    mask = torch.tensor([True]*graph.num_nodes, dtype=torch.bool)

    # Train 20 models in sequence; each one is trained on the nodes the
    # previous passes left in `mask`, and saved to its own file.
    for m_n in range(20):

        loader = NeighborLoader(graph, num_neighbors=[-1,-1], batch_size=5000,input_nodes=mask)
        total_loss = 0
        model.train()
        for subg in loader:
            optimizer.zero_grad()
            out = model(subg.x, subg.edge_index)
            loss = criterion(out, subg.y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item() * subg.batch_size
        # Average loss per seed node; avoid dividing by zero once no seed nodes remain.
        remaining = mask.sum().item()
        print(total_loss / remaining if remaining else 0.0)

        loader = NeighborLoader(graph, num_neighbors=[-1,-1], batch_size=5000,input_nodes=mask)
        model.eval()
        # Inference only: no autograd graph is needed for the filtering pass.
        with torch.no_grad():
            for subg in loader:
                out = model(subg.x, subg.edge_index)

                sorted_logits, indices = out.sort(dim=1,descending=True)
                # Relative margin between the two largest outputs, then shifted by the
                # batch minimum and divided by the batch maximum.
                conf = (sorted_logits[:,0] - sorted_logits[:,1]) / sorted_logits[:,0]
                conf = (conf - conf.min()) / conf.max()

                pred = indices[:,0]
                # Drop nodes that are predicted correctly or with high confidence.
                cond = (pred == subg.y) | (conf >= 0.9)
                # n_id and mask are CPU tensors, so the boolean index is moved to the CPU too.
                mask[subg.n_id[cond.cpu()]] = False

        torch.save(model.state_dict(), f'lword2vec_gnn_theia{m_n}_E3.pth')
        print(f'Model# {m_n}. {mask.sum().item()} nodes still misclassified \n')

In [None]:
from itertools import compress
from torch_geometric import utils

def Get_Adjacent(ids, mapp, edges, hops):
    """Collect the nodes lying on edges within ``hops`` expansion rounds of ``ids``.

    In every round, each edge with at least one endpoint in the current frontier
    contributes both of its endpoints; the collected nodes become the next
    frontier.

    Args:
        ids: Container of node ids to start from.
        mapp: Sequence translating node positions to node ids.
        edges: ``[sources, targets]`` lists of node positions.
        hops: Number of rounds; ``0`` returns an empty set.

    Returns:
        Set of node ids gathered over all rounds. Members of ``ids`` are only
        included if they lie on an edge.
    """
    if hops == 0:
        return set()

    reached = set()
    frontier = ids
    while True:
        layer = set()
        for src_pos, dst_pos in zip(edges[0], edges[1]):
            src_id, dst_id = mapp[src_pos], mapp[dst_pos]
            if src_id in frontier or dst_id in frontier:
                layer.add(src_id)
                layer.add(dst_id)
        reached |= layer

        if not hops > 1:
            break
        frontier, hops = layer, hops - 1

    return reached

def calculate_metrics(TP, FP, FN, TN):
    """Compute detection metrics from confusion-matrix counts.

    Returns:
        ``(precision, recall, fscore, FPR, TPR)``. Any ratio whose denominator
        is not positive is reported as ``0``.
    """
    def ratio(numerator, denominator):
        return numerator / denominator if denominator > 0 else 0

    FPR = ratio(FP, FP + TN)
    TPR = ratio(TP, TP + FN)

    prec = ratio(TP, TP + FP)
    rec = ratio(TP, TP + FN)
    fscore = ratio(2 * prec * rec, prec + rec)

    return prec, rec, fscore, FPR, TPR

def helper(MP, all_pids, GP, edges, mapp):
    """Score predicted nodes against ground truth with a two-hop tolerance and print the result.

    Args:
        MP: Set of node ids flagged by the model.
        all_pids: Set of all node ids.
        GP: Set of ground-truth malicious node ids.
        edges: ``[sources, targets]`` lists of node positions.
        mapp: Sequence translating node positions to node ids.

    Returns:
        ``(TPL, FPL)``: the true-positive set extended with missed ground-truth
        nodes lying within two hops of a true positive, and the false-positive
        set with nodes within two hops of the ground truth removed.
    """
    hits = MP & GP
    false_alarms = MP - GP
    missed = GP - MP
    negatives = all_pids - (GP | MP)

    near_truth = Get_Adjacent(GP, mapp, edges, 2)
    near_hits = Get_Adjacent(hits, mapp, edges, 2)

    # False alarms close to the ground truth are forgiven.
    FPL = false_alarms - near_truth
    # Missed nodes close to a hit are credited as detected.
    TPL = hits | (missed & near_hits)
    missed = missed - near_hits

    TP, FP, FN, TN = len(TPL), len(FPL), len(missed), len(negatives)

    prec, rec, fscore, _fpr, _tpr = calculate_metrics(TP, FP, FN, TN)
    print(f"True Positives: {TP}, False Positives: {FP}, False Negatives: {FN}")
    print(f"Precision: {round(prec, 2)}, Recall: {round(rec, 2)}, Fscore: {round(fscore, 2)}")

    return TPL, FPL

In [None]:
# Read the tab-separated test edge list; the handle is closed on leaving the block.
with open("theia_test.txt") as f:
    data = [line.split('\t') for line in f.read().split('\n')]
df = pd.DataFrame(data, columns=['actorID', 'actor_type', 'objectID', 'object', 'action', 'timestamp'])
# Rows with fewer than six fields (e.g. the empty string after the final
# newline) are padded with missing values and dropped here.
df = df.dropna()
# timestamp is still text at this point, so this is a string sort.
df = df.sort_values(by='timestamp', ascending=True)

In [None]:
# Attach the exec/path attributes from part 8 of the raw 6r log to the test edges.
# NOTE(review): this overwrites `df`, so re-running only this cell would merge into the
# already merged frame - re-run the loading cell above first.
df = add_attributes(df,"ta1-theia-e3-official-6r.json.8")

In [None]:
# Ground-truth malicious node ids.
with open("data_files/theia.json", "r") as json_file:
    GT_mal = set(json.load(json_file))

data = df

# Token lists, labels and edges for the test graph, then one embedding per node.
phrases, labels, edges, mapp = prepare_graph(data)
nodes = np.array([infer(x) for x in phrases])
et = time.time()

# Every node id that occurs as an actor or as an object.
all_ids = set(data['actorID']).union(data['objectID'])

In [None]:
graph = Data(x=torch.tensor(nodes,dtype=torch.float).to(device),y=torch.tensor(labels,dtype=torch.long).to(device), edge_index=torch.tensor(edges,dtype=torch.long).to(device))
graph.n_id = torch.arange(graph.num_nodes)
# flag[i] stays True until some model predicts node i's own type with enough confidence.
flag = torch.tensor([True]*graph.num_nodes, dtype=torch.bool)

model.eval()
for m_n in range(20):
    model.load_state_dict(torch.load(f'trained_weights/theia/lword2vec_gnn_theia{m_n}_E3.pth',map_location=torch.device('cpu')))
    loader = NeighborLoader(graph, num_neighbors=[-1,-1], batch_size=5000)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for subg in loader:
            out = model(subg.x, subg.edge_index)

            sorted_logits, indices = out.sort(dim=1,descending=True)
            # Relative margin between the two largest outputs, then shifted by the
            # batch minimum and divided by the batch maximum.
            conf = (sorted_logits[:,0] - sorted_logits[:,1]) / sorted_logits[:,0]
            conf = (conf - conf.min()) / conf.max()

            pred = indices[:,0]
            cond = (pred == subg.y) & (conf > 0.53)
            # Clear the flag for these nodes. n_id and flag are CPU tensors, so the
            # boolean index is moved to the CPU too.
            flag[subg.n_id[cond.cpu()]] = False

# Nodes whose flag is still set are reported as alerts and scored against the ground truth.
index = utils.mask_to_index(flag).tolist()
ids = set([mapp[x] for x in index])
alerts = helper(set(ids),set(all_ids),GT_mal,edges,mapp)

In [None]:
def traverse(ids, mapping, edges, hops, visited=None):
    """Breadth-first expansion from ``ids`` over an undirected view of ``edges``.

    ``hops`` counts levels including the start level: ``hops=1`` returns just
    ``ids``, ``hops=2`` adds their direct neighbours, and so on.

    Only nodes that were actually reached are marked as visited. Previously the
    endpoints of every scanned edge were marked, so a real neighbour was
    dropped whenever an unrelated edge earlier in the list mentioned it.

    Args:
        ids: Set of node ids forming the current level.
        mapping: Sequence translating node positions to node ids.
        edges: ``[sources, targets]`` lists of node positions.
        hops: Number of levels still to produce; ``0`` or less yields an empty set.
        visited: Set of node ids to exclude; updated in place. Callers normally
            leave this as ``None``.

    Returns:
        Set of node ids on the produced levels.
    """
    # Nothing left to expand: out of levels, or an empty frontier.
    if hops <= 0 or not ids:
        return set()

    if visited is None:
        visited = set()
    # The current level must never be re-entered from a later level.
    visited.update(ids)

    neighbors = set()
    for src, dst in zip(edges[0], edges[1]):
        src_mapped, dst_mapped = mapping[src], mapping[dst]

        # Follow the edge in either direction, but only onto unvisited nodes.
        if src_mapped in ids and dst_mapped not in visited:
            neighbors.add(dst_mapped)
        if dst_mapped in ids and src_mapped not in visited:
            neighbors.add(src_mapped)

    return ids.union(traverse(neighbors, mapping, edges, hops - 1, visited))

def load_data(file_path):
    """Read the JSON document stored at ``file_path`` and return the parsed object."""
    with open(file_path, 'r') as handle:
        parsed = json.load(handle)
    return parsed

def find_connected_alerts(start_alert, mapping, edges, depth, remaining_alerts):
    """Return the members of ``remaining_alerts`` that ``traverse`` reaches from ``start_alert``.

    The traversal starts from the single node ``start_alert`` and runs with
    ``depth`` as its hop budget.
    """
    return traverse({start_alert}, mapping, edges, depth).intersection(remaining_alerts)

def generate_incident_graphs(alerts, edges, mapping, depth):
    """Group alerts that are connected to one another within ``depth``.

    Alerts are taken one at a time (in arbitrary set order) as a seed. The
    other still-unassigned alerts connected to the seed form a group; the group
    is kept only when it has more than one member, and its members are then
    removed from further consideration. The seed itself has already been
    removed from the pool when the group is computed, so it is not part of the
    group.

    Args:
        alerts: Iterable of alert node ids.
        edges: ``[sources, targets]`` lists of node positions.
        mapping: Sequence translating node positions to node ids.
        depth: Hop budget handed to ``find_connected_alerts``.

    Returns:
        List of sets of alert node ids, one per kept group.
    """
    pending = set(alerts)
    groups = []

    while pending:
        seed = pending.pop()
        linked = find_connected_alerts(seed, mapping, edges, depth, pending)

        # Groups with fewer than two members are discarded.
        if len(linked) <= 1:
            continue

        groups.append(linked)
        pending -= linked

    return groups
