In [None]:
from secgym.qagen.alert_graph import AlertGraph
import random
import numpy as np
import json
import os
def compute_overlap_score(path1, path2, alpha=3, beta=1):
    """
    Score how strongly two node paths overlap, based on their directed edges.

    An edge is an ordered pair of consecutive nodes, so ``[1, 2]`` and
    ``[2, 1]`` have no edge in common.

    Parameters:
        path1 (sequence): The first path as a sequence of hashable nodes.
        path2 (sequence): The second path as a sequence of hashable nodes.
        alpha (float): Weight for shared edges (positive contribution).
        beta (float): Weight for unshared edges (negative contribution).

    Returns:
        float: ``alpha * shared - beta * unshared / (edges1 + edges2)``,
        or 0 when either path has fewer than two nodes or when the paths
        share no edge.
    """
    # A path needs at least two nodes to contain an edge.
    if min(len(path1), len(path2)) <= 1:
        return 0

    # Consecutive node pairs form the (deduplicated) edge sets.
    first_edges = set(zip(path1, path1[1:]))
    second_edges = set(zip(path2, path2[1:]))

    common = first_edges & second_edges
    if not common:
        return 0
    # Edges that appear in exactly one of the two paths.
    differing = first_edges ^ second_edges

    # NOTE(review): only the penalty term is normalised by the edge count;
    # the reward term is not, so the score is not confined to
    # [-beta, alpha]. Confirm whether the whole expression was meant to be
    # divided before enabling any rescaling to [0, 1].
    penalty = beta * len(differing) / (len(first_edges) + len(second_edges))
    return alpha * len(common) - penalty

def split_train_test(all_paths: list, train_ratio: float = 0.9, trials=5):
    """
    Generate several random train/test splits of alert paths and score each.

    ``all_paths`` is a list of dictionaries shaped like::

        {
            'start_alert': 24,
            'end_alert': 11,
            'start_entities': [12],
            'end_entities': [9],
            'shortest_alert_path': [24, 9, 11]
        }

    Every path is identified by the even-indexed elements of its
    ``shortest_alert_path`` (``[24, 11]`` above). Paths that reduce to the
    same key are collapsed, and the last one seen after shuffling wins.

    Parameters:
        all_paths (list): Path dictionaries as described above. The list is
            not modified.
        train_ratio (float): Fraction of the distinct keys placed in the
            train split (rounded down).
        trials (int): Number of random splits to generate.

    Returns:
        tuple: ``(score_splits, path_to_dict)`` where ``score_splits`` maps
        the summed train-vs-test overlap score of a trial to its
        ``(train_keys, test_keys)`` pair, and ``path_to_dict`` maps each key
        back to its original dictionary. Trials with the same total score
        overwrite each other, so ``score_splits`` may hold fewer than
        ``trials`` entries.

    Raises:
        ValueError: If ``all_paths`` is empty.
    """
    if not all_paths:
        raise ValueError("all_paths must contain at least one path")

    # Shuffle a copy so the caller's list keeps its original order.
    shuffled = list(all_paths)
    random.shuffle(shuffled)

    # Map each reduced path (indices 0, 2, 4, ...) to its source dictionary.
    path_to_dict = {}
    for path in shuffled:
        path_to_dict[tuple(path['shortest_alert_path'][::2])] = path

    alert_paths = list(path_to_dict)
    total = len(alert_paths)
    train_len = int(total * train_ratio)
    test_len = total - train_len

    # Length weights, normalised to sum to one. A weight above the average
    # makes a path more likely to be drawn into the test split below.
    max_length = max(len(p) for p in alert_paths)
    lweights = np.array([len(p) / max_length for p in alert_paths])
    lweights = lweights / lweights.sum()
    avg = lweights.mean()

    # The overlap score is symmetric, so cache it once per unordered pair.
    score_cache = {}

    def get_score(i, j):
        if i == j:
            return 0
        key = (i, j) if i < j else (j, i)
        if key not in score_cache:
            score_cache[key] = compute_overlap_score(alert_paths[i], alert_paths[j])
        return score_cache[key]

    score_splits = {}
    for _ in range(trials):
        # Work with indices so scoring needs no linear list.index() lookups.
        train_idx = []
        test_idx = []
        for i in range(total):
            # Once one side is full, everything left goes to the other side.
            if len(train_idx) >= train_len:
                test_idx.extend(range(i, total))
                break
            if len(test_idx) >= test_len:
                train_idx.extend(range(i, total))
                break
            if random.random() > (lweights[i] / (lweights[i] + avg)):
                train_idx.append(i)
            else:
                test_idx.append(i)

        # Total overlap between the two sides of this split.
        compare_score = 0
        for ti in train_idx:
            for si in test_idx:
                compare_score += get_score(ti, si)

        score_splits[compare_score] = (
            [alert_paths[i] for i in train_idx],
            [alert_paths[i] for i in test_idx],
        )

    return score_splits, path_to_dict

# Input locations (machine-specific absolute paths).
graph_path = "/Users/kevin/Downloads/SecRL/secgym/qagen/graph_files"
qa_path = "/Users/kevin/Downloads/SecRL/secgym/qagen/graph_path"

# Running totals across all processed graph files.
train_total_count = 0
test_total_count = 0

# Output directories, one per split flavour.
# NOTE(review): the median directory is spelled "media_split"; kept as-is
# because existing outputs or readers may already use that name.
median_score_path = "./media_split"
high_score_path = "./high_split"
low_score_path = "./low_split"

# Make sure every output directory exists before anything is written.
for _split_dir in (median_score_path, high_score_path, low_score_path):
    os.makedirs(_split_dir, exist_ok=True)

def save_to_split(path, filename, train_keys, test_keys, path_to_dict):
    """
    Write one train/test split to ``<path>/<stem>.json``.

    Parameters:
        path (str): Directory that receives the JSON file.
        filename (str): Source file name; everything from its first ``.``
            onwards is replaced by ``.json``.
        train_keys (iterable): Keys of ``path_to_dict`` forming the train set.
        test_keys (iterable): Keys of ``path_to_dict`` forming the test set.
        path_to_dict (dict): Maps each key to its JSON-serialisable entry.
    """
    stem = filename.split(".")[0]
    target = os.path.join(path, stem + ".json")
    with open(target, "w") as out:
        payload = {
            "train": [path_to_dict[key] for key in train_keys],
            "test": [path_to_dict[key] for key in test_keys],
        }
        json.dump(payload, out)

# Build low / median / high overlap splits for every graph file and keep
# running totals of the split sizes.
# sorted() makes the processing order reproducible across platforms.
for filename in sorted(os.listdir(graph_path)):
    if not filename.endswith(".graphml"):
        continue
    print(filename)

    graphfile = os.path.join(graph_path, filename)
    alert_graph = AlertGraph()
    alert_graph.load_graph_from_graphml(graphfile)
    all_paths = alert_graph.get_alert_paths(verbose=False)

    # Graphs with fewer than 150 paths use a fixed ratio; larger ones hold
    # out roughly 100 paths for testing.
    if len(all_paths) < 150:
        train_ratio = 0.288
    else:
        train_ratio = 1 - 100 / len(all_paths)
    print("Path length:", len(all_paths), "Train ratio:", train_ratio)

    score_splits, path_to_dict = split_train_test(all_paths, train_ratio, trials=50)

    # Pick the splits with the lowest, median and highest overlap score.
    scores = sorted(score_splits)
    median = scores[len(scores) // 2]
    high = scores[-1]
    low = scores[0]

    save_to_split(median_score_path, filename, *score_splits[median], path_to_dict)
    save_to_split(high_score_path, filename, *score_splits[high], path_to_dict)
    save_to_split(low_score_path, filename, *score_splits[low], path_to_dict)

    print("Median score:", median, "High score:", high, "Low score:", low)

    # Previously the totals were never updated, so the summary below always
    # reported 0. The median split's sizes are used for the tally.
    train_keys, test_keys = score_splits[median]
    print("Train set:", len(train_keys), "Test set:", len(test_keys))
    train_total_count += len(train_keys)
    test_total_count += len(test_keys)
    print("-" * 100)

print("Total train set:", train_total_count)
print("Total test set:", test_total_count)

In [None]:
# read low split test sample and count number of examples for each length of shortest_alert_path
import json
import os
import numpy as np
import matplotlib.pyplot as plt

# For each saved low-score split, report how many test examples exist for
# every length of 'shortest_alert_path'.
for filename in os.listdir(low_score_path):
    if not filename.endswith(".json"):
        continue

    with open(os.path.join(low_score_path, filename), "r") as f:
        data = json.load(f)

    test_set = data["test"]
    lengths = [len(p["shortest_alert_path"]) for p in test_set]

    # Tallying over the sorted lengths leaves the dict ordered by length.
    count = {}
    for length in sorted(lengths):
        count[length] = count.get(length, 0) + 1

    print(f"{filename}:")
    for k, v in count.items():
        print(f"Length {k}: {v} examples")

    print("============================================")
