# Flash Evaluation on DARPA E3 Cadets Dataset: 

This notebook evaluates Flash on the DARPA E3 Cadets dataset. Cadets is a node-level dataset, so Flash is run in its node-level configuration. Because the dataset lacks essential attributes for some node types, Flash cannot be run in decoupled mode with offline GNN embeddings here; instead, an online GNN is combined with word2vec semantic embeddings to obtain the evaluation results.

## Dataset Access: 
- Access the Cadets dataset via the following link: [Cadets Dataset](https://drive.google.com/drive/folders/1QlbUFWAGq3Hpl8wVdzOdIoZLFxkII4EK).
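- The download cell below fetches the dataset files with `gdown`; it is commented out by default, so uncomment it (or download the archives manually) if the files are not already present.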

## Data Parsing and Execution:
- The script is designed to automatically parse the downloaded data files.
- Execute all cells within this notebook to obtain the evaluation results.

## Model Training and Execution Flexibility:
- The notebook is configured to use pre-trained model weights by default.
- It also provides the option to set parameters for independently training Graph Neural Networks (GNNs) and word2vec models.
- These newly trained models can then be utilized for a comprehensive evaluation of the dataset.

Adhere to these steps for a detailed and effective analysis of the Cadets dataset using Flash.


In [1]:
import pandas as pd
import numpy as np
from sklearn.preprocessing import LabelEncoder
import torch
from torch_geometric.data import Data
import os
import torch.nn.functional as F
# import orjson as json
import warnings
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
warnings.filterwarnings('ignore')
from torch_geometric.loader import NeighborLoader
import multiprocessing
import csv

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")

import subprocess

# Select the compute device.
# When CUDA is available, ask `nvidia-smi` for the free memory of every GPU and
# pick the one with the most free memory. If `nvidia-smi` is missing, fails, or
# prints something unparsable, fall back to the default CUDA device instead of
# crashing. Without CUDA the CPU is used and `nvidia-smi` is never invoked.
# NOTE(review): nvidia-smi indices are assumed to match torch's CUDA indices;
# this may not hold when CUDA_VISIBLE_DEVICES is set -- confirm.
gpu_mem = {}      # GPU index -> free memory as reported by nvidia-smi
best_gpu = None   # index of the GPU with the most free memory, if known
device = torch.device('cpu')
if torch.cuda.is_available():
    try:
        smi_output = subprocess.check_output(
            ["nvidia-smi", "--query-gpu=index,memory.free", "--format=csv,noheader,nounits"],
            encoding='utf-8')
        gpu_mem = {int(x.split(',')[0]): int(x.split(',')[1])
                   for x in smi_output.strip().split('\n')}
        best_gpu = max(gpu_mem.items(), key=lambda x: x[1])[0]
        device = torch.device(f'cuda:{best_gpu}')
    except (OSError, subprocess.CalledProcessError, ValueError, IndexError):
        # nvidia-smi unavailable or its output could not be parsed.
        device = torch.device('cuda')
print(device)

%matplotlib inline

cuda:3


In [2]:
# import gdown
# urls = ["https://drive.google.com/file/d/1AcWrYiBmgAqp7DizclKJYYJJBQbnDMfb/view?usp=drive_link",
#       "https://drive.google.com/file/d/1EycO23tEvZVnN3VxOHZ7gdbSCwqEZTI1/view?usp=drive_link"]
# for url in urls:
#     gdown.download(url, quiet=False, use_cookies=False, fuzzy=True)

In [3]:
# Master switch: the training-data loading cell and the GNN training cell below
# are both guarded by `if Train:` and are skipped when this is False.
Train = True

In [4]:
from pprint import pprint
import gzip
from sklearn.manifold import TSNE
import json
import copy
import os

In [5]:
import re

def extract_uuid(line):
    """Return every value that follows a ``uuid":"`` key in *line*, in order.

    The result is an empty list when the line contains no such key.
    """
    return re.findall(r'uuid\":\"(.*?)\"', line)

def extract_subject_type(line):
    """Return every value that follows a ``type":"`` key in *line*, in order.

    Any key ending in ``type`` matches (the pattern is not anchored), and the
    result is an empty list when nothing matches.
    """
    return re.findall(r'type\":\"(.*?)\"', line)

def show(file_path):
    """Print a one-line progress message naming the file being processed."""
    print("Processing {}".format(file_path))

def extract_edge_info(line):
    """Extract the edge described by one serialized CDM18 event line.

    Args:
        line: One line of the JSON trace, as text.

    Returns:
        A tuple ``(src_id, edge_type, timestamp, dst_id1, dst_id2)``.
        ``timestamp`` is the raw text captured after ``timestampNanos":``.
        ``dst_id1`` / ``dst_id2`` are ``None`` when the corresponding
        ``predicateObject`` / ``predicateObject2`` UUID is absent.
        All five entries are ``None`` when the line has no subject UUID, no
        event type, or no timestamp, so callers can skip it with the same
        ``src_id is None`` check.
    """
    pattern_src = re.compile(r'subject\":{\"com.bbn.tc.schema.avro.cdm18.UUID\":\"(.*?)\"}')
    pattern_dst1 = re.compile(r'predicateObject\":{\"com.bbn.tc.schema.avro.cdm18.UUID\":\"(.*?)\"}')
    pattern_dst2 = re.compile(r'predicateObject2\":{\"com.bbn.tc.schema.avro.cdm18.UUID\":\"(.*?)\"}')
    pattern_type = re.compile(r'type\":\"(.*?)\"')
    pattern_time = re.compile(r'timestampNanos\":(.*?),')

    edge_types = pattern_type.findall(line)
    timestamps = pattern_time.findall(line)
    src_ids = pattern_src.findall(line)

    # A line missing any mandatory field cannot produce an edge; report it the
    # same way as a line without a subject instead of raising IndexError.
    if not (edge_types and timestamps and src_ids):
        return None, None, None, None, None

    edge_type = edge_types[0]
    timestamp = timestamps[0]
    src_id = src_ids[0]

    dst_id1 = pattern_dst1.findall(line)
    dst_id2 = pattern_dst2.findall(line)

    if len(dst_id1) > 0 and dst_id1[0] != 'null':
        dst_id1 = dst_id1[0]
    else:
        dst_id1 = None

    if len(dst_id2) > 0 and dst_id2[0] != 'null':
        dst_id2 = dst_id2[0]
    else:
        dst_id2 = None

    return src_id, edge_type, timestamp, dst_id1, dst_id2

def process_data(file_path):
    """Build a ``uuid -> node type`` map from a trace file and its numbered parts.

    Reads ``file_path`` and then ``file_path.1``, ``file_path.2``, ... (up to
    ``.99``), stopping at the first part that does not exist.

    Event, Host, TimeMarker, StartMarker, UnitDependency and EndMarker records
    are ignored. For every other line the first ``uuid`` is mapped to the first
    ``type`` value found on the line; lines without a ``type`` are mapped to
    ``MemoryObject`` / ``NetFlowObject`` / ``UnnamedPipeObject`` when the line
    names one of those records. Lines with no uuid, or with neither a type nor
    a known typeless record, are skipped rather than raising IndexError.

    Args:
        file_path: Path of the first trace file, as a string.

    Returns:
        dict mapping node uuid to its type string.
    """
    # Record kinds that never describe a graph node.
    skipped_records = (
        'com.bbn.tc.schema.avro.cdm18.Event',
        'com.bbn.tc.schema.avro.cdm18.Host',
        'com.bbn.tc.schema.avro.cdm18.TimeMarker',
        'com.bbn.tc.schema.avro.cdm18.StartMarker',
        'com.bbn.tc.schema.avro.cdm18.UnitDependency',
        'com.bbn.tc.schema.avro.cdm18.EndMarker',
    )
    # Node records that carry no "type" field; checked in this order.
    typeless_records = {
        'com.bbn.tc.schema.avro.cdm18.MemoryObject': 'MemoryObject',
        'com.bbn.tc.schema.avro.cdm18.NetFlowObject': 'NetFlowObject',
        'com.bbn.tc.schema.avro.cdm18.UnnamedPipeObject': 'UnnamedPipeObject',
    }

    id_nodetype_map = {}
    notice_num = 1000000
    for i in range(100):
        now_path = file_path + '.' + str(i)
        if i == 0:
            now_path = file_path
        if not os.path.exists(now_path):
            break

        with open(now_path, 'r') as f:
            show(now_path)
            for cnt, line in enumerate(f, start=1):
                if cnt % notice_num == 0:
                    print(cnt)

                if any(marker in line for marker in skipped_records):
                    continue

                uuids = extract_uuid(line)
                if not uuids:
                    # Blank or otherwise unidentifiable line.
                    continue
                uuid = uuids[0]

                subject_type = extract_subject_type(line)
                if subject_type:
                    id_nodetype_map[uuid] = subject_type[0]
                    continue

                for marker, node_type in typeless_records.items():
                    if marker in line:
                        id_nodetype_map[uuid] = node_type
                        break

    return id_nodetype_map

def process_edges(file_path, id_nodetype_map):
    """Write a tab-separated edge list next to each trace file part.

    For ``file_path`` and its numbered parts (``.1`` ... ``.99``, stopping at
    the first missing one) a sibling ``<part>.txt`` file is created. Every
    Event line whose subject is present in ``id_nodetype_map`` yields one row
    per known predicate object:
    ``src_id, src_type, dst_id, dst_type, edge_type, timestamp``.
    Events without a known subject are counted locally and skipped.
    """
    progress_every = 1000000
    skipped_events = 0

    for part in range(100):
        part_path = file_path + '.' + str(part)
        if part == 0:
            part_path = file_path
        if not os.path.exists(part_path):
            break

        with open(part_path, 'r') as reader, open(part_path+'.txt', 'w') as writer:
            line_no = 0
            for record in reader:
                line_no += 1
                if line_no % progress_every == 0:
                    print(line_no)

                if 'com.bbn.tc.schema.avro.cdm18.Event' not in record:
                    continue

                src_id, edge_type, timestamp, dst_id1, dst_id2 = extract_edge_info(record)
                if src_id is None or src_id not in id_nodetype_map:
                    skipped_events += 1
                    continue

                src_type = id_nodetype_map[src_id]
                # First predicate object, then the second, each only if known.
                for dst_id in (dst_id1, dst_id2):
                    if dst_id is None or dst_id not in id_nodetype_map:
                        continue
                    dst_type = id_nodetype_map[dst_id]
                    writer.write(f"{src_id}\t{src_type}\t{dst_id}\t{dst_type}\t{edge_type}\t{timestamp}\n")

def run_data_processing():
    """Unpack both Cadets archives, derive edge lists, and copy the train/test splits.

    Runs in the current working directory: extracts the two ``.tar.gz``
    archives with ``tar``, builds the node-type map and edge list for each
    trace, then copies the chosen edge-list files to ``cadets_train.txt`` and
    ``cadets_test.txt``. Shell command exit codes are not checked.
    """
    for extract_cmd in ('tar -zxvf ta1-cadets-e3-official.json.tar.gz',
                        'tar -zxvf ta1-cadets-e3-official-2.json.tar.gz'):
        os.system(extract_cmd)

    for path in ('ta1-cadets-e3-official.json', 'ta1-cadets-e3-official-2.json'):
        node_types = process_data(path)
        process_edges(path, node_types)

    for copy_cmd in ('cp ta1-cadets-e3-official.json.1.txt cadets_train.txt',
                     'cp ta1-cadets-e3-official-2.json.txt cadets_test.txt'):
        os.system(copy_cmd)


In [6]:
# run_data_processing()

In [7]:
def add_node_properties(nodes, node_id, properties):
    """Append *properties* to the list stored for *node_id*, creating it if needed."""
    nodes.setdefault(node_id, []).extend(properties)

def update_edge_index(edges, edge_index, index):
    """Translate ``(src_id, dst_id)`` pairs to positions and append them in place.

    ``index`` maps node ids to integer positions; the source position goes to
    ``edge_index[0]`` and the destination position to ``edge_index[1]``.
    """
    for src_id, dst_id in edges:
        # Resolve both endpoints before touching the output lists.
        src_pos, dst_pos = index[src_id], index[dst_id]
        edge_index[0].append(src_pos)
        edge_index[1].append(dst_pos)

def prepare_graph(df):
    """Turn an event table into per-node token lists, labels and an edge index.

    Each row contributes the tokens ``[exec, action]`` (plus ``path`` when it
    is truthy) to both its actor node and its object node, and one
    actor -> object edge. Nodes are numbered in order of first appearance.

    Returns:
        ``(features, feat_labels, edge_index, node_ids, index_map)`` where
        ``features[i]`` is the token list of node ``i``, ``feat_labels[i]`` its
        integer type label, ``edge_index`` is ``[sources, targets]`` in node
        positions, ``node_ids`` lists the ids by position, and ``index_map``
        maps id -> position.
    """
    dummies = {'SUBJECT_PROCESS': 0, 'FILE_OBJECT_FILE': 1, 'FILE_OBJECT_UNIX_SOCKET': 2, 
               'UnnamedPipeObject': 3, 'NetFlowObject': 4, 'FILE_OBJECT_DIR': 5}
    nodes = {}
    labels = {}
    edges = []

    for _, row in df.iterrows():
        action = row["action"]
        properties = [row['exec'], action]
        if row['path']:
            properties = properties + [row['path']]

        actor_id = row["actorID"]
        add_node_properties(nodes, actor_id, properties)
        labels[actor_id] = dummies[row['actor_type']]

        object_id = row["objectID"]
        add_node_properties(nodes, object_id, properties)
        labels[object_id] = dummies[row['object']]

        edges.append((actor_id, object_id))

    # Dict order is first-appearance order, which defines the node positions.
    features = list(nodes.values())
    feat_labels = [labels[node_id] for node_id in nodes]
    index_map = {node_id: position for position, node_id in enumerate(nodes)}

    edge_index = [[], []]
    update_edge_index(edges, edge_index, index_map)

    return features, feat_labels, edge_index, list(index_map), index_map

# Node type name -> integer class label (six classes). This module-level copy
# holds the same mapping that prepare_graph defines locally for its labels.
dummies = {'SUBJECT_PROCESS': 0, 'FILE_OBJECT_FILE': 1, 'FILE_OBJECT_UNIX_SOCKET': 2, 
               'UnnamedPipeObject': 3, 'NetFlowObject': 4, 'FILE_OBJECT_DIR': 5}

from collections import defaultdict


def add_csv_edges_efficient(original_features, original_labels, original_edge_index, original_mapp, original_mapidx, csv_path, thres=450): 
    """Return a copy of the graph with extra edges read from a CSV file.

    The first two columns of every CSV row are taken as ``(src_id, dst_id)``.
    An edge is added only when both ids already exist in ``original_mapidx``
    and both endpoints have a degree of at most ``thres``. Degrees are counted
    once, on the original edges, and are not updated as edges are added.
    Blank rows and rows with fewer than two columns are ignored.

    Args:
        original_features: Per-node feature lists (copied shallowly).
        original_labels: Per-node labels.
        original_edge_index: ``[sources, targets]`` lists of node positions.
        original_mapp: Node ids by position.
        original_mapidx: dict mapping node id -> position.
        csv_path: Path of the CSV file holding the candidate edges.
        thres: Maximum endpoint degree for which an edge is still added.

    Returns:
        ``(features, labels, edge_index, mapp, mapidx)`` -- new containers; the
        inputs are not modified.
    """
    csv_edges = []
    # newline='' is what the csv module expects for files it parses.
    with open(csv_path, 'r', encoding='utf-8', newline='') as f:
        for row in csv.reader(f):
            if len(row) < 2:
                # Blank or malformed line: nothing to connect.
                continue
            csv_edges.append((row[0], row[1]))
    
    extended_features = original_features.copy()
    extended_labels = original_labels.copy()
    extended_edge_index = [original_edge_index[0].copy(), original_edge_index[1].copy()]
    extended_mapp = original_mapp.copy()
    extended_mapidx = original_mapidx.copy()

    # Degree = number of times a node appears as either endpoint.
    degrees = defaultdict(int)
    for src_idx in extended_edge_index[0]:
        degrees[src_idx] += 1
    for dst_idx in extended_edge_index[1]:
        degrees[dst_idx] += 1
    
    all_degrees = sorted(degrees.values(), reverse=True)
    print(all_degrees[:3000])

    count1, count2, count3 = 0, 0, 0  # added / pruned by degree / unknown endpoint
    for src_id, dst_id in csv_edges:
        if src_id in extended_mapidx and dst_id in extended_mapidx:
            src_idx = extended_mapidx[src_id]
            dst_idx = extended_mapidx[dst_id]
            if degrees[src_idx] <= thres and degrees[dst_idx] <= thres:
                extended_edge_index[0].append(src_idx)
                extended_edge_index[1].append(dst_idx)
                count1 += 1
            else:
                count2 += 1
        else:
            count3 += 1
    print(f'Sucessfully add edges: {count1}\tPrune edges: {count2}\tFail to add edges: {count3}')
    
    return extended_features, extended_labels, extended_edge_index, extended_mapp, extended_mapidx

In [8]:
from torch_geometric.nn import GCNConv
from torch_geometric.nn import SAGEConv, GATConv
import torch.nn.functional as F
import torch.nn as nn

class GCN(torch.nn.Module):
    """Two-layer GraphSAGE node classifier (built from ``SAGEConv`` layers).

    ``in_channel`` input features are mapped to a 4-dimensional hidden layer
    and then to ``out_channel`` class scores.
    """

    def __init__(self,in_channel,out_channel):
        super().__init__()
        hidden_channel = 4
        self.conv1 = SAGEConv(in_channel, hidden_channel, normalize=True)
        self.conv2 = SAGEConv(hidden_channel, out_channel, normalize=True)

    def forward(self, x, edge_index):
        """Return per-node class probabilities (softmax over dim 1)."""
        hidden = F.relu(self.conv1(x, edge_index))
        # Dropout is only active while the module is in training mode.
        hidden = F.dropout(hidden, p=0.5, training=self.training)

        scores = self.conv2(hidden, edge_index)
        # NOTE(review): the training cell feeds this softmax output to
        # CrossEntropyLoss, which normally expects unnormalised scores; left
        # unchanged because the saved weights depend on it -- confirm intent.
        return F.softmax(scores, dim=1)

In [9]:
def visualize(h, color):
    """Show a 2-D t-SNE scatter plot of the rows of tensor *h*, coloured by *color*."""
    points = h.detach().cpu().numpy()
    z = TSNE(n_components=2).fit_transform(points)
    xs, ys = z[:, 0], z[:, 1]

    plt.figure(figsize=(10,10))
    # Hide both axes' ticks; only the point layout matters.
    plt.xticks([])
    plt.yticks([])

    plt.scatter(xs, ys, s=70, c=color, cmap="Set2")
    plt.show()

In [10]:
from gensim.models.callbacks import CallbackAny2Vec
import gensim
from gensim.models import Word2Vec
from multiprocessing import Pool
from itertools import compress
from tqdm import tqdm
import time

class EpochSaver(CallbackAny2Vec):
    """Callback that saves the model after each epoch.

    Args:
        path: File the model is saved to at the end of every epoch. Defaults
            to ``'word2vec_cadets_E3.model'``, the file name that was
            previously hard-coded, so ``EpochSaver()`` behaves as before.
    """

    def __init__(self, path='word2vec_cadets_E3.model'):
        self.epoch = 0    # number of epochs completed so far
        self.path = path  # checkpoint destination, overwritten every epoch

    def on_epoch_end(self, model):
        """Save *model* to ``self.path`` and advance the epoch counter."""
        model.save(self.path)
        self.epoch += 1

In [11]:
class EpochLogger(CallbackAny2Vec):
    """Callback that prints a line at the start and end of every training epoch."""

    def __init__(self):
        # Index of the epoch currently running (0-based).
        self.epoch = 0

    def on_epoch_begin(self, model):
        """Announce that the current epoch is starting."""
        print(f"Epoch #{self.epoch} start")

    def on_epoch_end(self, model):
        """Announce that the current epoch finished, then advance the counter."""
        print(f"Epoch #{self.epoch} end")
        self.epoch = self.epoch + 1

In [12]:
# Callback instances referenced by the (commented-out) Word2Vec training cells
# below via `callbacks=[saver,logger]`.
logger = EpochLogger()
saver = EpochSaver()

In [13]:
def add_attributes(d,p):
    """Attach ``exec`` and ``path`` attributes from the raw JSON log to edge table *d*.

    Every line of file *p* containing the text ``"EVENT"`` is parsed as JSON.
    Each such record yields one row for its ``predicateObject`` and, when a
    ``predicateObject2`` UUID is present, an additional row for that object
    (paired with ``predicateObject2Path``). Fields that are missing or null
    become the empty string. Both tables are converted to strings and
    inner-joined on ``actorID``, ``objectID``, ``action`` and ``timestamp``.

    Args:
        d: DataFrame with at least the four join columns.
        p: Path of the JSON-lines trace file.

    Returns:
        The de-duplicated inner join of *d* with the extracted attributes.
        Empty (instead of raising) when the file contains no events.

    Raises:
        json.JSONDecodeError: if a line containing ``"EVENT"`` is not valid JSON.
    """
    event_key = 'com.bbn.tc.schema.avro.cdm18.Event'
    uuid_key = 'com.bbn.tc.schema.avro.cdm18.UUID'
    columns = ['actorID', 'objectID', 'action', 'timestamp', 'exec', 'path']

    def lookup(record, *keys):
        # Walk nested containers; a missing key or a null/non-mapping level
        # yields '' (only these lookup failures are swallowed).
        try:
            for key in keys:
                record = record[key]
        except (KeyError, TypeError, IndexError):
            return ''
        return record

    info = []
    # Stream the file so the handle is closed and the parsed log is never
    # held in memory all at once.
    with open(p) as f:
        for line in f:
            if "EVENT" not in line:
                continue
            event = lookup(json.loads(line), 'datum', event_key)

            action = lookup(event, 'type')
            actor = lookup(event, 'subject', uuid_key)
            obj = lookup(event, 'predicateObject', uuid_key)
            timestamp = lookup(event, 'timestampNanos')
            cmd = lookup(event, 'properties', 'map', 'exec')
            path = lookup(event, 'predicateObjectPath', 'string')
            path2 = lookup(event, 'predicateObject2Path', 'string')

            # The second object's row is emitted first, and only if it exists.
            try:
                obj2 = event['predicateObject2'][uuid_key]
            except (KeyError, TypeError, IndexError):
                pass
            else:
                info.append({'actorID':actor,'objectID':obj2,'action':action,'timestamp':timestamp,'exec':cmd, 'path':path2})

            info.append({'actorID':actor,'objectID':obj,'action':action,'timestamp':timestamp,'exec':cmd, 'path':path})

    # Explicit columns keep the join keys present even when no event was found.
    rdf = pd.DataFrame(info, columns=columns).astype(str)
    d = d.astype(str)

    return d.merge(rdf,how='inner',on=['actorID','objectID','action','timestamp']).drop_duplicates()

In [14]:
if Train:
    # Load the tab-separated training edge list; the context manager closes
    # the file handle once the text has been read.
    with open("../cadets_train.txt") as f:
        data = f.read().split('\n')
    data = [line.split('\t') for line in data]
    df = pd.DataFrame (data, columns = ['actorID', 'actor_type','objectID','object','action','timestamp'])
    # Rows with missing fields (e.g. the trailing empty line) are removed here.
    df = df.dropna()
    # Timestamps are still strings at this point, so this is a string sort.
    df.sort_values(by='timestamp', ascending=True,inplace=True)
    # Join exec/path attributes from the raw JSON log onto the edges.
    df = add_attributes(df,"../ta1-cadets-e3-official.json.1")
    phrases,labels,edges,mapp,mapidx = prepare_graph(df)
    # Extend the graph with the additional edges listed in the CSV file.
    csv_path = "../../flash-add-edge/cadets_train.csv"
    phrases, labels, edges, mapp, mapidx = add_csv_edges_efficient(
        phrases, labels, edges, mapp, mapidx, csv_path
    )

[484562, 166868, 86735, 80498, 80097, 79600, 75947, 65805, 53932, 40934, 35035, 34875, 27892, 26597, 25340, 22022, 20011, 19218, 17790, 16506, 16498, 16009, 15512, 15309, 12798, 12556, 12536, 12366, 11537, 11501, 9901, 9786, 9621, 9174, 8952, 8012, 8004, 7970, 7196, 7114, 7113, 6620, 6202, 6163, 6067, 6004, 5784, 5531, 5506, 5368, 4682, 4315, 4038, 4003, 4002, 3806, 3715, 3428, 3389, 3314, 3237, 3015, 2856, 2761, 2759, 2611, 2593, 2481, 2473, 2395, 2106, 2083, 2076, 2051, 2033, 2028, 2018, 2002, 2001, 2001, 2001, 1980, 1958, 1880, 1862, 1800, 1760, 1734, 1731, 1643, 1601, 1600, 1596, 1568, 1563, 1548, 1544, 1533, 1527, 1525, 1505, 1503, 1503, 1501, 1495, 1494, 1484, 1478, 1441, 1441, 1426, 1421, 1408, 1372, 1364, 1359, 1278, 1272, 1241, 1224, 1212, 1173, 1168, 1154, 1094, 1091, 1076, 1050, 1037, 1022, 1016, 1015, 1008, 984, 969, 968, 947, 929, 925, 919, 918, 910, 898, 898, 898, 897, 890, 880, 868, 866, 866, 865, 862, 854, 845, 833, 832, 816, 815, 815, 812, 811, 811, 811, 810, 810, 810,

In [15]:
from sklearn.utils import class_weight
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss

# model = GCN(30,6).to(device)
# 4 input features per node and 6 output classes. The input size has to match
# the vectors produced by infer(); the commented line above is the 30-dim variant.
model = GCN(4,6).to(device)
# Adam with L2 weight decay; one optimizer instance is shared by all training rounds.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

In [16]:
# if Train:
#     word2vec = Word2Vec(sentences=phrases, vector_size=30, window=5, min_count=1, workers=8,epochs=300,callbacks=[saver,logger])

In [17]:
# if Train:
#     word2vec = Word2Vec(sentences=phrases, vector_size=10, window=5, min_count=1, workers=8,epochs=300,callbacks=[saver,logger])

In [18]:
import math
import torch
import numpy as np
from gensim.models import Word2Vec

class PositionalEncoder:
    """Sinusoidal positional encoding table added to a sequence of embeddings.

    Args:
        d_model: Embedding dimension. Odd values are supported: the sine
            columns (even indices) then outnumber the cosine columns by one.
        max_len: Number of positions precomputed in ``self.pe``.
    """

    def __init__(self, d_model, max_len=100000):
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        self.pe = torch.zeros(max_len, d_model)
        self.pe[:, 0::2] = torch.sin(angles)
        # There are only d_model // 2 odd columns; trimming the angles keeps
        # the assignment valid when d_model is odd (no-op when it is even).
        self.pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])

    def embed(self, x):
        """Return ``x`` plus the encodings of positions ``0 .. x.size(0) - 1``.

        ``x`` must have at most ``max_len`` rows and ``d_model`` columns;
        longer inputs fail with a shape mismatch.
        """
        return x + self.pe[:x.size(0)]


def infer(document):
    """Embed a node's token list as the mean of its position-encoded word vectors.

    Tokens missing from the module-level ``w2vmodel`` vocabulary are ignored.
    The remaining vectors get the module-level ``encoder``'s positional
    encoding added (skipped for documents of 100000 tokens or more) and are
    then averaged.

    Args:
        document: Sequence of tokens describing one node.

    Returns:
        1-D numpy array of length ``w2vmodel.wv.vector_size``; all zeros when
        no token is in the vocabulary.
    """
    word_embeddings = [w2vmodel.wv[word] for word in document if word in w2vmodel.wv]
    
    if not word_embeddings:
        # Match the model's embedding size instead of a hard-coded width so
        # the fallback stays stackable with real embeddings.
        return np.zeros(w2vmodel.wv.vector_size)
    
    # Stack into one ndarray first: building a tensor directly from a list of
    # arrays is much slower.
    output_embedding = torch.tensor(np.array(word_embeddings), dtype=torch.float)
    if len(document) < 100000:
        output_embedding = encoder.embed(output_embedding)

    output_embedding = output_embedding.detach().cpu().numpy()
    
    mean_embedding = np.mean(output_embedding, axis=0)
    
    return mean_embedding

# Module-level objects read by infer(). The encoder dimension (4) has to equal
# the vector size of the loaded Word2Vec model; the commented lines are the
# 30-dimensional / pre-trained-weights alternatives.
# encoder = PositionalEncoder(30)
encoder = PositionalEncoder(4)
# w2vmodel = Word2Vec.load("trained_weights/cadets/word2vec_cadets_E3.model")
w2vmodel = Word2Vec.load("../word2vec_cadets_E3.model")

In [19]:
from torch_geometric import utils

# Train a sequence of 22 model checkpoints on the training graph. After every
# round, nodes that are classified correctly (or with high confidence) are
# removed from `mask`, so the next round is seeded only with the remaining nodes.
if Train:
    l = np.array(labels)
    
    # One loss weight per class (6 classes, same order as the `dummies` labels).
    # NOTE(review): the origin of these constants is not shown here -- confirm.
    ngram_class_weights = [
        0.7339,  
        1.0213,   
        1.1339,   
        0.5000,   
        1.5000,   
        0.8151  
    ]
    class_weights = torch.tensor(ngram_class_weights,dtype=torch.float).to(device)
    
    criterion = CrossEntropyLoss(weight=class_weights,reduction='mean')

    # Embed every node's token list with the word2vec-based infer().
    nodes = [infer(x) for x in phrases]
    nodes = np.array(nodes)  

    graph = Data(x=torch.tensor(nodes,dtype=torch.float).to(device),y=torch.tensor(labels,dtype=torch.long).to(device), edge_index=torch.tensor(edges,dtype=torch.long).to(device))
    # Global node ids, used to map mini-batch rows back to positions in `mask`.
    graph.n_id = torch.arange(graph.num_nodes)
    # True = node is still used as a seed node; initially every node.
    mask = torch.tensor([True]*graph.num_nodes, dtype=torch.bool)
    
    # Average loss of each round.
    losses = []

    for m_n in range(22):

      # --- one training pass over the nodes still in `mask` ---
      loader = NeighborLoader(graph, num_neighbors=[-1,-1], batch_size=5000,input_nodes=mask)
      total_loss = 0
      for subg in loader:
          model.train()
          optimizer.zero_grad() 
          out = model(subg.x, subg.edge_index) 
          # The loss is computed over every row of the sampled subgraph.
          loss = criterion(out, subg.y) 
          loss.backward() 
          optimizer.step()      
          total_loss += loss.item() * subg.batch_size
      # Normalise by the number of seed nodes used in this round.
      losses.append(total_loss / mask.sum().item())
      print(total_loss / mask.sum().item())

      # --- evaluation pass: drop nodes that no longer need training ---
      loader = NeighborLoader(graph, num_neighbors=[-1,-1], batch_size=5000,input_nodes=mask)
      for subg in loader:
          model.eval()
          out = model(subg.x, subg.edge_index)

          # Relative margin between the two highest class probabilities.
          sorted, indices = out.sort(dim=1,descending=True)
          conf = (sorted[:,0] - sorted[:,1]) / sorted[:,0]
          # Shift by the batch minimum, then divide by the pre-shift batch maximum.
          # NOTE(review): this is not a standard min-max normalisation -- confirm intent.
          conf = (conf - conf.min()) / conf.max()

          pred = indices[:,0]
          # A node is done when it is predicted correctly or with confidence >= 0.9.
          cond = (pred == subg.y) | (conf >= 0.9)
          subg.n_id = subg.n_id.to(device)
          mask[subg.n_id[cond]] = False

      # One checkpoint per round; the evaluation cell loads all 22 of them.
      torch.save(model.state_dict(), f'lword2vec_gnn_cadets{m_n}_E3.pth')
      print(f'Model# {m_n}. {mask.sum().item()} nodes still misclassified \n')

1.7157109024642074


Model# 0. 63959 nodes still misclassified 



1.749706493538736
Model# 1. 23568 nodes still misclassified 



1.7980547039102552
Model# 2. 23525 nodes still misclassified 

1.7898650894002883


Model# 3. 23486 nodes still misclassified 

1.7750847627628725
Model# 4. 19483 nodes still misclassified 



1.7831676224935828
Model# 5. 18873 nodes still misclassified 



1.7607884162535874
Model# 6. 17940 nodes still misclassified 



1.7571397129375665
Model# 7. 9864 nodes still misclassified 

1.7715551816440653


Model# 8. 9736 nodes still misclassified 

1.7559255280255683
Model# 9. 9687 nodes still misclassified 



1.757235201062021
Model# 10. 9671 nodes still misclassified 

1.7432824892189618
Model# 11. 9636 nodes still misclassified 



1.7488973289666307
Model# 12. 9569 nodes still misclassified 

1.7364291804924892
Model# 13. 9343 nodes still misclassified 



1.7320710079682455
Model# 14. 7062 nodes still misclassified 

1.7317774552726368
Model# 15. 6806 nodes still misclassified 



1.7312906700969968
Model# 16. 6788 nodes still misclassified 

1.7254704291776972
Model# 17. 6740 nodes still misclassified 



1.7046938805622587
Model# 18. 6547 nodes still misclassified 

1.7153665789469736
Model# 19. 6542 nodes still misclassified 

1.7217129924471537
Model# 20. 6542 nodes still misclassified 

1.716755916313526
Model# 21. 6542 nodes still misclassified 



In [20]:
# plt.figure(figsize=(10, 6))
# plt.plot(range(len(losses)), losses, marker='o', linestyle='-', color='b')
# plt.title("GNN Training Loss Curve on Cadets")
# plt.xlabel("Epoch")
# plt.ylabel("Loss")
# plt.xticks(range(0, 22, 2))  # show the epoch numbers on the x-axis
# plt.grid(True)
# plt.show()

# print(losses)

In [21]:
from itertools import compress
from torch_geometric import utils

def Get_Adjacent(ids, mapp, edges, hops):
    """Collect the endpoints of all edges within *hops* hops of *ids*.

    ``edges`` is ``[sources, targets]`` in node positions and ``mapp``
    translates a position to a node id. An edge contributes both of its
    endpoints as soon as either one is in the current id set. ``hops == 0``
    yields an empty set.
    """
    if hops == 0:
        return set()

    touched = set()
    for src_pos, dst_pos in zip(edges[0], edges[1]):
        src_id, dst_id = mapp[src_pos], mapp[dst_pos]
        if src_id in ids or dst_id in ids:
            touched.add(src_id)
            touched.add(dst_id)

    if not hops > 1:
        return touched

    # Expand once more from everything reached so far.
    return touched.union(Get_Adjacent(touched, mapp, edges, hops - 1))

def calculate_metrics(TP, FP, FN, TN):
    """Return ``(precision, recall, f-score, FPR, TPR)`` from confusion counts.

    Any ratio whose denominator is not positive is reported as 0.
    """
    def ratio(numerator, denominator):
        return numerator / denominator if denominator > 0 else 0

    prec = ratio(TP, TP + FP)
    rec = ratio(TP, TP + FN)
    fscore = ratio(2 * prec * rec, prec + rec)

    FPR = ratio(FP, FP + TN)
    TPR = ratio(TP, TP + FN)

    return prec, rec, fscore, FPR, TPR

def helper(MP, all_pids, GP, edges, mapp):
    """Score predicted malicious ids *MP* against ground truth *GP* with 2-hop leniency.

    A false positive within two hops of a ground-truth node is not counted,
    and a missed ground-truth node within two hops of a true positive is
    counted as detected. Prints the resulting counts and metrics.

    Returns:
        ``(TPL, FPL)`` -- the adjusted true-positive and false-positive id sets.
    """
    hits = MP & GP
    false_alarms = MP - GP
    missed = GP - MP
    negatives = all_pids - (GP | MP)

    near_truth = Get_Adjacent(GP, mapp, edges, 2)
    near_hits = Get_Adjacent(hits, mapp, edges, 2)

    # Apply the neighbourhood leniency to each category.
    FPL = false_alarms - near_truth
    TPL = hits | (missed & near_hits)
    missed = missed - near_hits

    tp_count = len(TPL)
    fp_count = len(FPL)
    fn_count = len(missed)
    tn_count = len(negatives)

    prec, rec, fscore, FPR, TPR = calculate_metrics(tp_count, fp_count, fn_count, tn_count)
    print(f"True Positives: {tp_count}, True Negatives: {tn_count}, False Positives: {fp_count}, False Negatives: {fn_count}")
    print(f"Precision: {round(prec, 2)}, Recall: {round(rec, 2)}, Fscore: {round(fscore, 2)}")
    
    return TPL, FPL

In [22]:
# Load the tab-separated test edge list; the context manager closes the file
# handle once the text has been read.
with open("../cadets_test.txt") as f:
    data = f.read().split('\n')
data = [line.split('\t') for line in data]
df = pd.DataFrame (data, columns = ['actorID', 'actor_type','objectID','object','action','timestamp'])
# Rows with missing fields (e.g. the trailing empty line) are removed here.
df = df.dropna()
# Timestamps are still strings at this point, so this is a string sort.
df.sort_values(by='timestamp', ascending=True,inplace=True)

In [23]:
# Join exec/path attributes from the raw test trace onto the test edges.
df = add_attributes(df,"../ta1-cadets-e3-official-2.json")

In [24]:
# Ground-truth ids, loaded from JSON and used as the malicious set when scoring.
with open("../../data_files/cadets.json", "r") as json_file:
    GT_mal = set(json.load(json_file))

data = df

# Build the test graph the same way as the training graph.
phrases,labels,edges,mapp,mapidx = prepare_graph(data)

# Extend it with the additional edges listed in the CSV file.
csv_path = "../../flash-add-edge/cadets_test.csv"
phrases, labels, edges, mapp, mapidx = add_csv_edges_efficient(
    phrases, labels, edges, mapp, mapidx, csv_path
)

# Embed every node's token list with the word2vec-based infer().
nodes = []
for i, x in enumerate(phrases):
    embedding = infer(x)
    nodes.append(embedding)
    
nodes = np.array(nodes)

# Every id that appears as an actor or an object in the test events.
all_ids = list(data['actorID']) + list(data['objectID'])
all_ids = set(all_ids)

[428308, 189402, 80187, 73509, 73018, 68393, 65941, 57255, 56958, 46775, 44153, 33784, 32798, 24986, 24408, 22444, 22073, 19477, 19188, 18747, 18545, 17688, 16774, 16723, 16667, 15720, 14944, 14282, 14213, 12877, 11311, 11084, 11010, 10002, 9983, 9255, 8615, 8574, 8430, 7087, 7079, 6514, 6220, 5925, 5603, 5492, 5412, 5306, 5226, 5185, 4903, 4816, 4348, 4073, 4044, 3771, 3590, 3576, 3539, 3536, 3359, 3350, 3278, 2997, 2966, 2668, 2660, 2657, 2599, 2597, 2538, 2516, 2513, 2507, 2438, 2432, 2431, 2426, 2423, 2423, 2421, 2419, 2419, 2403, 2321, 2307, 2305, 2250, 2145, 2062, 1892, 1845, 1840, 1811, 1809, 1802, 1799, 1785, 1773, 1769, 1768, 1768, 1748, 1748, 1740, 1739, 1723, 1670, 1656, 1646, 1622, 1593, 1590, 1567, 1561, 1555, 1549, 1537, 1523, 1496, 1494, 1487, 1465, 1462, 1460, 1452, 1443, 1442, 1436, 1434, 1413, 1406, 1386, 1380, 1371, 1362, 1353, 1351, 1336, 1334, 1323, 1321, 1319, 1316, 1300, 1300, 1297, 1292, 1289, 1281, 1281, 1265, 1256, 1255, 1253, 1250, 1245, 1234, 1222, 1220, 121

In [25]:
# Evaluate the 22 saved checkpoints on the test graph. A node stays flagged
# only if none of the checkpoints predicts its type correctly.
graph = Data(x=torch.tensor(nodes,dtype=torch.float).to(device),y=torch.tensor(labels,dtype=torch.long).to(device), edge_index=torch.tensor(edges,dtype=torch.long).to(device))
# Global node ids, used to map mini-batch rows back to positions in `flag`.
graph.n_id = torch.arange(graph.num_nodes)
# True = node has not been classified correctly by any checkpoint so far.
flag = torch.tensor([True]*graph.num_nodes, dtype=torch.bool).to(device)

for m_n in range(22):
  # model.load_state_dict(torch.load(f'trained_weights/cadets/lword2vec_gnn_cadets{m_n}_E3.pth'))
  model.load_state_dict(torch.load(f'lword2vec_gnn_cadets{m_n}_E3.pth'))
  loader = NeighborLoader(graph, num_neighbors=[-1,-1], batch_size=5000)    
  for subg in loader:
      model.eval()
      out = model(subg.x, subg.edge_index)

      sorted, indices = out.sort(dim=1,descending=True)
      # `conf` is computed as in training but is not used by the condition below.
      conf = (sorted[:,0] - sorted[:,1]) / sorted[:,0]
      conf = (conf - conf.min()) / conf.max()
    
      pred = indices[:,0]
      cond = (pred == subg.y)
      subg.n_id = subg.n_id.to(device)
      # AND-ing with an all-False tensor clears the flag of every correctly
      # predicted node.
      flag[subg.n_id[cond]] = torch.logical_and(flag[subg.n_id[cond]], torch.tensor([False]*len(flag[subg.n_id[cond]]), dtype=torch.bool).to(device))
# Positions still flagged -> node ids; these are treated as the predicted malicious set.
index = utils.mask_to_index(flag).tolist()
ids = set([mapp[x] for x in index])
_ = helper(set(ids),set(all_ids),GT_mal,edges,mapp) 

True Positives: 12846, True Negatives: 336694, False Positives: 1393, False Negatives: 12
Precision: 0.9, Recall: 1.0, Fscore: 0.95


In [26]:
def traverse(ids, mapping, edges, hops, visited=None):
    """Expand the id set *ids* along *edges* for up to *hops* recursion levels.

    ``edges`` is ``[sources, targets]`` in node positions and ``mapping``
    translates a position to a node id. At each level an edge adds both of its
    endpoints when one endpoint is in ``ids`` and the other has not been
    visited yet. ``hops == 0`` returns an empty set, so the last frontier is
    never included in the result.

    ``visited`` is shared across recursion levels and is mutated in place when
    supplied by the caller.

    NOTE(review): every endpoint of every scanned edge is marked visited,
    whether or not it was reached, so the result depends on edge order and
    deeper levels rarely add anything. Preserved as-is -- confirm intent.
    """
    if hops == 0:
        return set()

    seen = set() if visited is None else visited

    frontier = set()
    for src_pos, dst_pos in zip(edges[0], edges[1]):
        src_id = mapping[src_pos]
        dst_id = mapping[dst_pos]

        reached_from_src = src_id in ids and dst_id not in seen
        reached_from_dst = dst_id in ids and src_id not in seen
        if reached_from_src or reached_from_dst:
            frontier.update((src_id, dst_id))

        seen.update((src_id, dst_id))

    # Only ids not already in the current set seed the next level.
    frontier.difference_update(ids)
    deeper = traverse(frontier, mapping, edges, hops - 1, seen)
    return ids.union(deeper)

def load_data(file_path):
    """Read the file at *file_path* and return its parsed JSON content."""
    with open(file_path, 'r') as handle:
        text = handle.read()
    return json.loads(text)

def find_connected_alerts(start_alert, mapping, edges, depth, remaining_alerts):
    """Return the alerts in *remaining_alerts* that ``traverse`` reaches from *start_alert*."""
    reachable = traverse({start_alert}, mapping, edges, depth)
    linked_alerts = reachable.intersection(remaining_alerts)
    return linked_alerts

def generate_incident_graphs(alerts, edges, mapping, depth):
    """Group alerts into incidents of mutually reachable alerts.

    Alerts are taken one at a time (in arbitrary set order). The alerts still
    pending that are connected to it within *depth* form a candidate group;
    groups with more than one member are kept and their members are removed
    from the pending set. The seed alert itself has already been removed when
    its group is computed.

    Returns:
        List of alert-id sets, one per incident.
    """
    pending = set(alerts)
    incidents = []

    while pending:
        seed = pending.pop()
        linked = find_connected_alerts(seed, mapping, edges, depth, pending)

        if len(linked) <= 1:
            continue

        incidents.append(linked)
        pending -= linked

    return incidents