# Imports

In [1]:
import os
import pickle
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from transformers import AutoTokenizer, AutoModel
from tqdm import tqdm
from sentence_transformers import SentenceTransformer
import matplotlib.pyplot as plt
import seaborn as sns
from torch.nn.functional import normalize
import numpy as np
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from IPython.display import clear_output
import time  # Optional, to slow down updates a little

# Path Declaration

In [15]:
# Project root = three directory levels above the current working directory.
# NOTE: this relies on os.getcwd(), so the result changes if the kernel is
# started from a different folder.
project_base_path = os.path.dirname(os.path.dirname(os.path.dirname(os.getcwd())))
project_base_path

'/home/ANONYMOUS/projects/FALCON'

In [None]:
# Location of the v1 pickle under the project root; it is loaded further down
# with load_from_pickle().
saved_v1_generated_data_path = os.path.join(project_base_path, "data/generation/snort/snort3-community-rules_v1.pkl")
saved_v1_generated_data_path

'/home/ANONYMOUS/projects/FALCON/data/generation/snort/snort3-community-rules_v1.pkl'

In [None]:
# Location of the v2 pickle under the project root; it is loaded further down
# with load_from_pickle().
saved_v2_generated_data_path = os.path.join(project_base_path, "data/generation/snort/snort3-community-rules_v2.pkl")
saved_v2_generated_data_path

'/home/ANONYMOUS/projects/FALCON/data/generation/snort/snort3-community-rules_v2.pkl'

# Environment Setup

In [32]:
# Fixed seed; passed as random_state to train_test_split so the split is reproducible.
SEED = 42

In [33]:
# "OPENAI_KEY" is a placeholder, not a real credential: replace it before
# running any cell that calls the OpenAI API, and never commit a real key.
# NOTE: the assignment below overwrites any OPENAI_API_KEY that is already
# set in the process environment.
open_ai_key = "OPENAI_KEY"
os.environ['OPENAI_API_KEY'] = open_ai_key

# Load Dataset

In [34]:
def load_from_pickle(file_path) -> dict:
    """
    Read a pickled object back from disk.

    This is a best-effort loader: any failure (missing file, corrupt
    pickle, ...) is reported on stdout and signalled to the caller by a
    ``None`` result instead of a raised exception.

    :param file_path: Path to the pickle file
    :return: The unpickled object, or None when loading failed
    """
    try:
        with open(file_path, 'rb') as handle:
            payload = pickle.load(handle)
    except Exception as err:
        # Deliberately broad: callers only distinguish "loaded" from "not loaded".
        print(f"Error loading data from pickle: {err}")
        return None
    return payload

In [35]:
def get_first_n_elements(dictionary: dict, n: int) -> dict:
    """
    Return a new dictionary holding the leading entries of ``dictionary``.

    Entries are taken in the dictionary's iteration order using ordinary
    slice semantics (``items[:n]``), so an ``n`` larger than the dictionary
    simply returns a copy of everything.

    :param dictionary: The input dictionary
    :param n: The number of elements to retrieve
    :return: A dictionary with the first n elements
    """
    leading_items = tuple(dictionary.items())[:n]
    return {key: value for key, value in leading_items}

In [None]:
# Load the v1 data back from the pickle file.
# NOTE: load_from_pickle() returns None when loading fails, in which case the
# .keys() call below raises AttributeError.
loaded_v1_data = load_from_pickle(saved_v1_generated_data_path)
print(len(loaded_v1_data.keys()))

4017


In [37]:
# Keep the first 10 entries of the v1 data as a small sample; these entries
# are later removed from the training pool by remove_10_test_samples().
snort_cti_sample_dict = get_first_n_elements(loaded_v1_data, 10)

In [None]:
# Load the v2 data back from the pickle file.
# NOTE: load_from_pickle() returns None when loading fails, in which case the
# .keys() call below raises AttributeError.
loaded_v2_data = load_from_pickle(saved_v2_generated_data_path)
print(len(loaded_v2_data.keys()))

4017


In [39]:
snort_cti_sample_dict

{'alert tcp $HOME_NET 2589 -> $EXTERNAL_NET any ( msg:"MALWARE-BACKDOOR - Dagger_1.4.0"; flow:to_client,established; content:"2|00 00 00 06 00 00 00|Drives|24 00|",depth 16; metadata:ruleset community; classtype:misc-activity; sid:105; rev:14; )': '    Title: Detection of Dagger 1.4.0 Backdoor Activity Over TCP\n\n    Threat Category: Malware – Backdoor\n\n    Threat Name: Dagger 1.4.0\n\n    Detection Summary:\n\n    This signature is designed to detect network traffic associated with the Dagger 1.4.0 backdoor. The traffic is characterized by a specific sequence of bytes ("2|00 00 00 06 00 00 00|Drives|24 00|") found within the first 16 bytes of the data payload. This communication occurs from an infected internal host to an external destination and typically indicates unauthorized remote access capabilities.\n\n    Rule Metadata\n    Classification: Misc Activity\n\n    Ruleset: Community\n\n    Rule Logic Breakdown\n    Alert Type: alert\n\n    Protocol: tcp\n\n    Source IP: $HOME_

In [40]:
# Split the sample dict into two parallel lists: its keys (Snort rules) and
# its values (CTI texts). Index i of one list pairs with index i of the other.
snorts, ctis = zip(*snort_cti_sample_dict.items())
snorts = list(snorts)
ctis = list(ctis)

In [41]:
len(snorts), len(ctis)

(10, 10)

In [42]:
def format_cti_snort_data_to_training_data(data: list[dict]) -> list[tuple]:
    """
    Flatten several dictionaries into one list of ``(key, value)`` pairs.

    The dictionaries are walked in the order given and each one contributes
    its items in its own iteration order. Keys that occur in more than one
    dictionary produce one pair per occurrence (no de-duplication).

    :param data: The dictionaries to flatten
    :return: List of (key, value) tuples usable as training pairs
    """
    return [pair for mapping in data for pair in mapping.items()]

In [None]:
# Merge the v1 and v2 dictionaries into a single list of (key, value) pairs,
# used downstream as (anchor, positive) sentence pairs.
full_dataset = format_cti_snort_data_to_training_data([loaded_v1_data, loaded_v2_data])
print(len(full_dataset))

8034


In [44]:
def remove_10_test_samples(training_data: list[tuple], test_pairs: dict) -> list[tuple]:
    """
    Drop every training pair that overlaps with the held-out pairs.

    A pair is discarded when its first element is one of the keys of
    ``test_pairs`` OR its second element is one of the values of
    ``test_pairs``; all other pairs are kept in their original order.

    :param training_data: List of (key, value) pairs to filter
    :param test_pairs: Held-out mapping whose keys and values must not leak
    :return: The filtered list of pairs
    """
    # Sets give constant-time membership checks.
    held_out_keys = set(test_pairs)
    held_out_values = set(test_pairs.values())

    kept = []
    for first, second in training_data:
        if first in held_out_keys or second in held_out_values:
            continue
        kept.append((first, second))
    return kept

In [45]:
# Remove from the pool every pair whose key or value also appears in the
# held-out sample dict, so those samples cannot leak into training or testing.
full_dataset = remove_10_test_samples(full_dataset, snort_cti_sample_dict)
print(len(full_dataset))

8014


In [46]:
# Split into training and testing sets (90% train, 10% test); SEED makes the split reproducible
train_pairs, test_pairs = train_test_split(full_dataset, test_size=0.1, random_state=SEED)

# Training Setup

In [47]:
# Custom Dataset
class ContrastiveDataset(Dataset):
    """
    Dataset of (anchor, positive) text pairs, tokenized on access.

    Each item is tokenized with padding to a fixed length and truncation
    enabled, and returned as a dict with the keys ``input_ids_a`` /
    ``attention_mask_a`` (anchor) and ``input_ids_b`` / ``attention_mask_b``
    (positive).

    :param data: Indexable collection of (anchor, positive) pairs
    :param tokenizer: Callable tokenizer accepting a list of texts plus the
        ``padding``, ``truncation``, ``max_length`` and ``return_tensors``
        keyword arguments, and returning a mapping with ``input_ids`` and
        ``attention_mask``
    :param max_length: Optional fixed token length. When omitted (None) the
        module-level ``MAX_LEN`` is looked up at access time, which is the
        original behaviour.
    """

    def __init__(self, data, tokenizer, max_length=None):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        anchor, positive = self.data[idx]
        # Resolve the length lazily so that a MAX_LEN defined (or changed)
        # after construction is still honoured when no explicit value was given.
        max_length = MAX_LEN if self.max_length is None else self.max_length
        # Both texts are tokenized in one call: row 0 is the anchor, row 1 the positive.
        encoded = self.tokenizer([anchor, positive], padding="max_length", truncation=True,
                                 max_length=max_length, return_tensors="pt")
        return {
            "input_ids_a": encoded["input_ids"][0],
            "attention_mask_a": encoded["attention_mask"][0],
            "input_ids_b": encoded["input_ids"][1],
            "attention_mask_b": encoded["attention_mask"][1],
        }

In [48]:
# Bi-Encoder Model
class SentenceEncoder(nn.Module):
    """
    Sentence encoder built on a pretrained transformer backbone.

    The embedding of a sequence is the hidden state at position 0 of the
    backbone's last layer (the CLS token), scaled to unit L2 norm so that a
    dot product between two embeddings equals their cosine similarity.

    :param model_name: Identifier handed to ``AutoModel.from_pretrained``
    """

    def __init__(self, model_name):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)

    def forward(self, input_ids, attention_mask):
        """Return one L2-normalized embedding per input sequence."""
        hidden_states = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
        ).last_hidden_state
        cls_vectors = hidden_states[:, 0]  # first position of every sequence
        return nn.functional.normalize(cls_vectors, p=2, dim=1)


In [49]:
# Contrastive Loss (InfoNCE / NT-Xent)
def contrastive_loss(emb_a, emb_b, temperature=0.05):
    """
    In-batch contrastive loss between two aligned batches of embeddings.

    Row ``i`` of ``emb_a`` is paired with row ``i`` of ``emb_b``; every other
    row of ``emb_b`` acts as a negative for it. The pairwise dot products are
    divided by ``temperature`` and fed to cross-entropy with the diagonal as
    the target class.

    :param emb_a: 2-D tensor of anchor embeddings
    :param emb_b: 2-D tensor of positive embeddings, same number of rows
    :param temperature: Divisor applied to the similarity logits
    :return: Scalar loss tensor (mean over the batch)
    """
    logits = (emb_a @ emb_b.T) / temperature
    # The matching pair for row i sits in column i.
    targets = torch.arange(len(emb_a), device=emb_a.device)
    return nn.functional.cross_entropy(logits, targets)


# Scaling Function

In [50]:
def analyze_dot_product_matrix(matrix: torch.Tensor):
    """
    Computes statistics for the principal diagonal and off-diagonal values of a dot product matrix.

    For each of the two groups the mean, max, min and population standard
    deviation are reported as plain Python floats. When a group is empty
    (the off-diagonal part of a 1x1 matrix) all four statistics are NaN.

    Args:
        matrix (torch.Tensor): A square 2D floating-point tensor (dot product matrix).

    Returns:
        Tuple[dict, dict]: (diagonal_stats, off_diagonal_stats), each with
        the keys "mean", "max", "min" and "std".

    Raises:
        AssertionError: If ``matrix`` is not a square 2D tensor.
    """
    assert matrix.ndim == 2 and matrix.shape[0] == matrix.shape[1], "Matrix must be square"

    diag_vals = torch.diag(matrix)
    # Boolean mask that is True everywhere except on the principal diagonal.
    off_diag_mask = ~torch.eye(matrix.size(0), dtype=torch.bool, device=matrix.device)
    off_diag_vals = matrix[off_diag_mask]

    def stats(tensor):
        if tensor.numel() == 0:
            # max()/min() raise on empty tensors; report "no data" instead.
            nan = float("nan")
            return {"mean": nan, "max": nan, "min": nan, "std": nan}
        return {
            "mean": tensor.mean().item(),
            "max": tensor.max().item(),
            "min": tensor.min().item(),
            "std": tensor.std(unbiased=False).item(),  # population std
        }

    return stats(diag_vals), stats(off_diag_vals)

In [51]:
import numpy as np
from langchain.embeddings import OpenAIEmbeddings

def compute_dot_product_matrix_openai(test_snorts, test_ctis, batch_size=50):
    """
    Build the cosine-similarity matrix between CTI texts and Snort rules
    using embeddings obtained from ``OpenAIEmbeddings``.

    Texts are sent to the embedding service ``batch_size`` at a time: first
    all Snort rules, then the CTI texts. Every raw dot product is divided by
    the product of the two vectors' L2 norms, so entries are cosine
    similarities despite the function's name.

    :param test_snorts: List of Snort rule strings (matrix columns)
    :param test_ctis: List of CTI strings (matrix rows)
    :param batch_size: Number of texts per embedding request
    :return: NumPy array of shape (len(test_ctis), len(test_snorts))
    """
    embedder = OpenAIEmbeddings()  # Uses text-embedding-ada-002 by default

    # Step 1: embed every Snort rule once; these vectors are reused for all CTI batches.
    rule_vectors = []
    for start in range(0, len(test_snorts), batch_size):
        rule_vectors += embedder.embed_documents(test_snorts[start:start + batch_size])

    rule_matrix = np.array(rule_vectors)  # Shape: (N, D)
    rule_norms = np.linalg.norm(rule_matrix, axis=1, keepdims=True)

    # Step 2: embed the CTIs batch by batch and compare each batch with all rules.
    similarity_blocks = []
    for start in range(0, len(test_ctis), batch_size):
        cti_matrix = np.array(embedder.embed_documents(test_ctis[start:start + batch_size]))
        cti_norms = np.linalg.norm(cti_matrix, axis=1, keepdims=True)

        raw_dots = np.dot(cti_matrix, rule_matrix.T)
        norm_products = cti_norms @ rule_norms.T  # outer product of the two norm columns
        similarity_blocks.append(raw_dots / norm_products)

    # Stack the per-batch row blocks: final shape (len(test_ctis), len(test_snorts))
    return np.vstack(similarity_blocks)


## 10-Sample Validation Set

### Run - 0

In [31]:
# NOTE(review): `dot_product_matrix` is not assigned in any cell visible in
# this export — presumably produced by a cell that was removed; confirm
# before re-running.
diag_stats, off_diag_stats = analyze_dot_product_matrix(dot_product_matrix)

print("Diagonal Stats:", diag_stats)
print("Off-Diagonal Stats:", off_diag_stats)

Diagonal Stats: {'mean': 0.9303933382034302, 'max': 0.9669086337089539, 'min': 0.885727047920227, 'std': 0.02710534632205963}
Off-Diagonal Stats: {'mean': 0.1781265139579773, 'max': 0.46763429045677185, 'min': -0.04564410820603371, 'std': 0.10627235472202301}


### Run - 1

In [30]:
# NOTE(review): `dot_product_matrix` is not assigned in any cell visible in
# this export — presumably produced by a cell that was removed; confirm
# before re-running.
diag_stats, off_diag_stats = analyze_dot_product_matrix(dot_product_matrix)

print("Diagonal Stats:", diag_stats)
print("Off-Diagonal Stats:", off_diag_stats)

Diagonal Stats: {'mean': 0.9284042716026306, 'max': 0.9657737016677856, 'min': 0.8474253416061401, 'std': 0.03653645142912865}
Off-Diagonal Stats: {'mean': 0.10324714332818985, 'max': 0.2915942668914795, 'min': -0.056207071989774704, 'std': 0.08546042442321777}


### Run - 2

In [30]:
# NOTE(review): `dot_product_matrix` is not assigned in any cell visible in
# this export — presumably produced by a cell that was removed; confirm
# before re-running.
diag_stats, off_diag_stats = analyze_dot_product_matrix(dot_product_matrix)

print("Diagonal Stats:", diag_stats)
print("Off-Diagonal Stats:", off_diag_stats)

Diagonal Stats: {'mean': 0.9453257918357849, 'max': 0.9653378129005432, 'min': 0.9332507252693176, 'std': 0.011390076950192451}
Off-Diagonal Stats: {'mean': 0.11288418620824814, 'max': 0.3074943721294403, 'min': -0.050255388021469116, 'std': 0.08911170065402985}


## Test Set

In [52]:
# Split the held-out (first, second) pairs into two parallel lists, so that
# test_snorts[k] and test_ctis[k] come from the same pair.
test_snorts = [i[0] for i in test_pairs]
test_ctis = [i[1] for i in test_pairs]

### Run - 0

In [29]:
# NOTE(review): `compute_dot_product_matrix_batched`, `model` and `tokenizer`
# are not defined in any cell visible in this export — confirm where they
# come from before re-running.
dot_product_matrix_test = compute_dot_product_matrix_batched(
    model=model,
    tokenizer=tokenizer,
    test_snorts=test_snorts,
    test_ctis=test_ctis,
    batch_size=256
)

diag_stats, off_diag_stats = analyze_dot_product_matrix(dot_product_matrix_test)

print("Diagonal Stats:", diag_stats)
print("Off-Diagonal Stats:", off_diag_stats)

Diagonal Stats: {'mean': 0.931696355342865, 'max': 0.9832119941711426, 'min': 0.6484097838401794, 'std': 0.04163140058517456}
Off-Diagonal Stats: {'mean': 0.07574468851089478, 'max': 0.9773631691932678, 'min': -0.2572978138923645, 'std': 0.08740384131669998}


### Run - 1

In [26]:
# NOTE(review): `compute_dot_product_matrix_batched`, `model` and `tokenizer`
# are not defined in any cell visible in this export — confirm where they
# come from before re-running.
dot_product_matrix_test = compute_dot_product_matrix_batched(
    model=model,
    tokenizer=tokenizer,
    test_snorts=test_snorts,
    test_ctis=test_ctis,
    batch_size=256
)

diag_stats, off_diag_stats = analyze_dot_product_matrix(dot_product_matrix_test)

print("Diagonal Stats:", diag_stats)
print("Off-Diagonal Stats:", off_diag_stats)

Diagonal Stats: {'mean': 0.9362938404083252, 'max': 0.9846035242080688, 'min': 0.45150303840637207, 'std': 0.04456904157996178}
Off-Diagonal Stats: {'mean': 0.03960946574807167, 'max': 0.9812132120132446, 'min': -0.27644082903862, 'std': 0.08802148699760437}


### Run - 2

In [29]:
# NOTE(review): `compute_dot_product_matrix_batched`, `model` and `tokenizer`
# are not defined in any cell visible in this export — confirm where they
# come from before re-running.
dot_product_matrix_test = compute_dot_product_matrix_batched(
    model=model,
    tokenizer=tokenizer,
    test_snorts=test_snorts,
    test_ctis=test_ctis,
    batch_size=256
)

diag_stats, off_diag_stats = analyze_dot_product_matrix(dot_product_matrix_test)

print("Diagonal Stats:", diag_stats)
print("Off-Diagonal Stats:", off_diag_stats)

Diagonal Stats: {'mean': 0.9444412589073181, 'max': 0.9851973056793213, 'min': 0.34648361802101135, 'std': 0.04367481917142868}
Off-Diagonal Stats: {'mean': 0.03628416731953621, 'max': 0.9788450598716736, 'min': -0.30357807874679565, 'std': 0.08703625202178955}


# Semantic Evaluation

In [53]:
import numpy as np
from sklearn.metrics import f1_score

def evaluate_similarity_with_auto_threshold_numpy(dot_product_matrix: np.ndarray):
    """
    Evaluates diagonal recall and best F1-score based on thresholded sigmoid scores
    of a square NumPy similarity matrix.

    Entry (i, j) is treated as a positive pair when i == j and as a negative
    pair otherwise. Scores are squashed with a sigmoid, 100 evenly spaced
    thresholds between the smallest and largest squashed score are tried, and
    the first threshold with the highest F1 is reported.

    The counting is vectorized: for every threshold
    F1 = 2*TP / (2*TP + FP + FN) = 2*TP / (#predicted positive + #actual positive).

    Args:
        dot_product_matrix (np.ndarray): Square similarity matrix (N x N), N >= 1.
            Values are converted to float64 before scoring.

    Returns:
        dict: {
            'recall_diag': float,     # share of rows whose argmax lies on the diagonal
            'f1_best': float,
            'best_threshold': float,  # threshold on the sigmoid scale
            'sigmoid_min': float,
            'sigmoid_max': float,
        }

    Raises:
        AssertionError: If the input is not a square 2D matrix.
        ValueError: If the matrix is empty (0 x 0).
    """
    assert dot_product_matrix.ndim == 2 and dot_product_matrix.shape[0] == dot_product_matrix.shape[1], \
        "Input must be a square matrix."

    N = dot_product_matrix.shape[0]
    if N == 0:
        raise ValueError("Input matrix must contain at least one row.")

    raw_scores = np.asarray(dot_product_matrix, dtype=np.float64)

    # Apply sigmoid to scores
    sigmoid_scores = 1 / (1 + np.exp(-raw_scores))

    # Diagonal Recall: how often the highest score is at the correct (diagonal) position
    recall_diag = float(np.mean(np.argmax(raw_scores, axis=1) == np.arange(N)))

    # Flatten scores and ground-truth labels in the same (row-major) order
    flat_scores = sigmoid_scores.ravel()
    is_positive = np.eye(N, dtype=bool).ravel()
    actual_positives = N  # exactly one positive per row

    score_min = float(flat_scores.min())
    score_max = float(flat_scores.max())

    # Search for best threshold to maximize F1
    best_f1 = 0.0
    best_threshold = 0.0
    for t in np.linspace(score_min, score_max, num=100):
        predicted = flat_scores >= t
        true_positives = np.count_nonzero(predicted & is_positive)
        predicted_positives = np.count_nonzero(predicted)
        # The denominator is never zero because actual_positives >= 1.
        f1 = 2.0 * true_positives / (predicted_positives + actual_positives)
        if f1 > best_f1:  # strict '>' keeps the first (lowest) best threshold
            best_f1 = f1
            best_threshold = float(t)

    return {
        "recall_diag": recall_diag,
        "f1_best": best_f1,
        "best_threshold": best_threshold,
        "sigmoid_min": score_min,
        "sigmoid_max": score_max,
    }


In [54]:
# Split the held-out (first, second) pairs into two parallel lists, so that
# test_snorts[k] and test_ctis[k] come from the same pair.
test_snorts = [i[0] for i in test_pairs]
test_ctis = [i[1] for i in test_pairs]

In [55]:
len(test_snorts), len(test_ctis)

(802, 802)

## Pre-trained

In [56]:
# Baseline: score the test pairs with OpenAI embeddings and evaluate the
# resulting similarity matrix.
# NOTE: this embeds every test Snort rule and CTI text through the OpenAI
# API, so it needs a valid OPENAI_API_KEY and network access.
dot_matrix = compute_dot_product_matrix_openai(test_snorts, test_ctis)
metrics = evaluate_similarity_with_auto_threshold_numpy(dot_matrix)

# Every metric is a float, so a fixed 4-decimal format works for all of them.
for k, v in metrics.items():
    print(f"{k}: {v:.4f}")


  embedder = OpenAIEmbeddings()  # Uses text-embedding-ada-002 by default


recall_diag: 0.9102
f1_best: 0.6300
best_threshold: 0.7068
sigmoid_min: 0.6656
sigmoid_max: 0.7136
