# Imports

In [1]:
import os
import pickle
from torch import nn
import torch
from openai import OpenAI
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel

# File Path Declaration

In [2]:
project_base_path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.getcwd())))))
project_base_path

'/home/ANONYMOUS/projects/FALCON'

In [3]:
saved_v1_generated_data_path = os.path.join(project_base_path, "data/generation/snort/snort3-community-rules_v1.pkl")
saved_v1_generated_data_path

'/home/ANONYMOUS/projects/FALCON/data/generation/snort/snort3-community-rules_v1.pkl'

In [4]:
saved_v2_generated_data_path = os.path.join(project_base_path, "data/generation/snort/snort3-community-rules_v2.pkl")
saved_v2_generated_data_path

'/home/ANONYMOUS/projects/FALCON/data/generation/snort/snort3-community-rules_v2.pkl'

In [35]:
generated_rule_dir_path = os.path.join(project_base_path, "results/e2e/quantitative/snort/gpt_4o")
generated_rule_dir_path

'/home/ANONYMOUS/projects/FALCON/results/e2e/quantitative/snort/gpt_4o'

# Environment

In [6]:
open_ai_key = "OPENAI_KEY"
client = OpenAI(api_key=open_ai_key)
open_ai_model_name = "gpt-4o"
SEED = 42
DEVICE = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
MAX_LEN = 512

In [7]:
def get_open_ai_response(prompt: str, model_name: str = open_ai_model_name) -> str:

  # Set the client with API key
  client = OpenAI(
    api_key=open_ai_key,  # This is the default and can be omitted
  )

  try:
      # Call the OpenAI API
      chat_completion = client.chat.completions.create(
          messages=[{"role": "system", "content": "You are a helpful assistant."},
                    {"role": "user", "content": prompt}],
          model=model_name
      )

      # Extract and return the assistant's reply
      return chat_completion.choices[0].message.content

  except Exception as e:
      return f"@@$$## Error communicating with OpenAI API: {str(e)}"

In [8]:
get_open_ai_response("Hello, which gpt version are you?")

"Hello! I am based on OpenAI's GPT-4 architecture. How can I assist you today?"

# Helper Functions

In [9]:
def load_from_pickle(file_path) -> dict:
    """
    Loads data from a pickle file.

    :param file_path: Path to the pickle file
    :return: Loaded data
    """
    try:
        with open(file_path, 'rb') as file:
            return pickle.load(file)
    except Exception as e:
        print(f"Error loading data from pickle: {e}")
        return None

In [10]:
def get_first_n_elements(dictionary: dict, n: int) -> dict:
    """
    Get the first n elements of a dictionary.

    :param dictionary: The input dictionary
    :param n: The number of elements to retrieve
    :return: A dictionary with the first n elements
    """
    return dict(list(dictionary.items())[:n])

In [11]:
def save_string_as_txt(directory_path, file_name, content):
    """
    Saves a given string as a .txt file in the specified directory.

    Args:
        directory_path (str): Path to the directory where the file should be saved.
        file_name (str): Desired name of the file (with or without .txt extension).
        content (str): The string content to be written to the file.

    Returns:
        str: Full path to the saved file if successful, otherwise an empty string.
    """
    if not file_name.lower().endswith('.txt'):
        file_name += '.txt'
    
    file_path = os.path.join(directory_path, file_name)

    try:
        os.makedirs(directory_path, exist_ok=True)  # Create directory if it doesn't exist
        with open(file_path, 'w', encoding='utf-8') as file:
            file.write(content)
        return file_path
    except PermissionError:
        print(f"Error: Permission denied to write to '{file_path}'.")
    except Exception as e:
        print(f"An error occurred while saving the file: {e}")

    return ""

# Data Generation

In [12]:
# Load the data back from the pickle file
loaded_v1_data = load_from_pickle(saved_v1_generated_data_path)
print(len(loaded_v1_data.keys()))

4017


In [13]:
snort_cti_sample_dict = get_first_n_elements(loaded_v1_data, 10)

In [14]:
# Load the data back from the pickle file
loaded_v2_data = load_from_pickle(saved_v2_generated_data_path)
print(len(loaded_v2_data.keys()))

4017


In [15]:
snort_cti_sample_dict

{'alert tcp $HOME_NET 2589 -> $EXTERNAL_NET any ( msg:"MALWARE-BACKDOOR - Dagger_1.4.0"; flow:to_client,established; content:"2|00 00 00 06 00 00 00|Drives|24 00|",depth 16; metadata:ruleset community; classtype:misc-activity; sid:105; rev:14; )': '    Title: Detection of Dagger 1.4.0 Backdoor Activity Over TCP\n\n    Threat Category: Malware – Backdoor\n\n    Threat Name: Dagger 1.4.0\n\n    Detection Summary:\n\n    This signature is designed to detect network traffic associated with the Dagger 1.4.0 backdoor. The traffic is characterized by a specific sequence of bytes ("2|00 00 00 06 00 00 00|Drives|24 00|") found within the first 16 bytes of the data payload. This communication occurs from an infected internal host to an external destination and typically indicates unauthorized remote access capabilities.\n\n    Rule Metadata\n    Classification: Misc Activity\n\n    Ruleset: Community\n\n    Rule Logic Breakdown\n    Alert Type: alert\n\n    Protocol: tcp\n\n    Source IP: $HOME_

In [16]:
snorts, ctis = zip(*snort_cti_sample_dict.items())
snorts = list(snorts)
ctis = list(ctis)

In [17]:
len(snorts), len(ctis)

(10, 10)

In [18]:
def format_cti_snort_data_to_training_data(data: list[dict]) -> list[tuple]:
    """
    Format the CTI Snort data into training data.

    :param data: The data to format
    :return: Formatted training data
    """
    training_data = []
    for dataset in data:
        for key, value in dataset.items():
            training_data.append((key, value))
    return training_data

In [19]:
# Sample Dataset Format (list of (anchor, positive) sentence pairs)
full_dataset = format_cti_snort_data_to_training_data([loaded_v1_data, loaded_v2_data])
print(len(full_dataset))

8034


In [20]:
def remove_10_test_samples(training_data: list[tuple], test_pairs: dict) -> list[tuple]:
    # Extract all test keys and values into sets for quick lookup
    test_keys = set(test_pairs.keys())
    test_values = set(test_pairs.values())
    
    # Filter training data
    filtered_data = [(key, value) for key, value in training_data if key not in test_keys and value not in test_values]
    
    return filtered_data

In [21]:
# Sample Dataset Format (list of (anchor, positive) sentence pairs)
full_dataset = remove_10_test_samples(full_dataset, snort_cti_sample_dict)
print(len(full_dataset))

8014


In [22]:
# Split into training and testing sets (80% train, 20% test)
train_pairs, test_pairs = train_test_split(full_dataset, test_size=0.1, random_state=SEED)

In [23]:
len(test_pairs)

802

# Dataset Class

In [24]:
# Custom Dataset
class ContrastiveDataset(Dataset):
    def __init__(self, data, tokenizer):
        self.data = data
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        anchor, positive = self.data[idx]
        encoded = self.tokenizer([anchor, positive], padding="max_length", truncation=True,
                                 max_length=MAX_LEN, return_tensors="pt")
        return {
            "input_ids_a": encoded["input_ids"][0],
            "attention_mask_a": encoded["attention_mask"][0],
            "input_ids_b": encoded["input_ids"][1],
            "attention_mask_b": encoded["attention_mask"][1],
        }

In [25]:
# Bi-Encoder Model
class SentenceEncoder(nn.Module):
    def __init__(self, model_name):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)

    def forward(self, input_ids, attention_mask):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        embeddings = outputs.last_hidden_state[:, 0]  # CLS token
        return nn.functional.normalize(embeddings, p=2, dim=1)  # Normalize for cosine similarity


# Evaluation Functions

In [26]:
def generate_rule_from_cti_prompt(input_cti: str) -> str:

  rule_generation_prompt = f"""

    You are a cybersecurity expert tasked with performing Snort rule generation for a given Cyber Threat Intelligence (CTI).
    There is a sample task input and output provided below.
    
    Sample CTI Input and corresponding Snort Output:

    CTI Input:
        
      Title: Detection of QAZ Worm Client Login Activity over TCP Port 7597

      Threat Category: Malware – Backdoor

      Threat Name: QAZ Worm

      Detection Summary:

      This signature is designed to detect network traffic associated with the QAZ Worm, specifically its client login activity. The worm exhibits characteristic behavior by initiating a connection and transmitting a unique identifier string (qazwsx.hsq) to a remote server over TCP port 7597. This communication typically indicates the presence of a backdoor that allows unauthorized access to infected systems.

      Rule Metadata
      Classification: Misc Activity

      Ruleset: Community

      Rule Logic Breakdown
      Alert Type: alert

      Protocol: tcp

      Source IP: $EXTERNAL_NET (any IP address outside the local trusted network)

      Source Port: any

      Destination IP: $HOME_NET (any IP address inside the local trusted network)

      Destination Port: 7597 (known port used by the QAZ worm)

      Flow: to_server, established
      (Traffic must be flowing to a server and part of an established connection)

      Content Match: "qazwsx.hsq"
      (String in the payload that identifies the worm’s presence)

      Message: "MALWARE-BACKDOOR QAZ Worm Client Login access"

      Technical Details
      Port 7597 is not a standard well-known port and is leveraged by the QAZ Worm for backdoor communications.

      The content string "qazwsx.hsq" is a unique identifier used by the worm’s client when connecting to a command-and-control server or to another infected host.

      Detection relies on the presence of this string within an established TCP session directed to a host on the internal network.

      Indicators of Compromise (IOCs)
      String Pattern: qazwsx.hsq

      Destination Port: 7597/tcp

      Recommended Actions
      Block or restrict traffic on port 7597 at the perimeter firewall.

      Investigate any internal systems that initiate or receive such connections.

      Perform malware scanning and forensic analysis on potentially compromised hosts.

      Update endpoint and network defense signatures to ensure coverage against this and similar threats.
      
    Snort Output:
        
      alert tcp $EXTERNAL_NET any -> $HOME_NET 7597 ( msg:"MALWARE-BACKDOOR QAZ Worm Client Login access"; flow:to_server,established; content:"qazwsx.hsq"; metadata:ruleset community; classtype:misc-activity; sid:108; rev:12; )


    Generate Snort from the provided CTI. Do not include anything that is not provided.
    Do not print anything like sure here is the CTI or anything else. Only print the CTI. 

    CTI Input: 
    
      {input_cti}

    Snort Output:

  """

  return rule_generation_prompt

# Test Code

In [27]:
test_ctis = [i[1] for i in test_pairs]
gt_rules = [i[0] for i in test_pairs]

In [28]:
len(test_ctis), len(gt_rules)

(802, 802)

In [29]:
test_ctis[0]
gt_rules[0]

'alert tcp $EXTERNAL_NET $FILE_DATA_PORTS -> $HOME_NET any ( msg:"FILE-IDENTIFY Microsoft Compound File Binary v3 file magic detected"; flow:to_client,established; file_data; content:"|D0 CF 11 E0 A1 B1 1A E1|"; content:">|00 03 00|",within 4,distance 16; flowbits:set,file.ole; flowbits:noalert; metadata:policy balanced-ips alert,policy connectivity-ips alert,policy max-detect-ips alert,policy security-ips alert,ruleset community; service:ftp-data,http,imap,pop3; classtype:misc-activity; sid:16474; rev:27; )'

In [31]:
prompt = generate_rule_from_cti_prompt(test_ctis[0])
print(prompt)



    You are a cybersecurity expert tasked with performing Snort rule generation for a given Cyber Threat Intelligence (CTI).
    There is a sample task input and output provided below.
    
    Sample CTI Input and corresponding Snort Output:

    CTI Input:
        
      Title: Detection of QAZ Worm Client Login Activity over TCP Port 7597

      Threat Category: Malware – Backdoor

      Threat Name: QAZ Worm

      Detection Summary:

      This signature is designed to detect network traffic associated with the QAZ Worm, specifically its client login activity. The worm exhibits characteristic behavior by initiating a connection and transmitting a unique identifier string (qazwsx.hsq) to a remote server over TCP port 7597. This communication typically indicates the presence of a backdoor that allows unauthorized access to infected systems.

      Rule Metadata
      Classification: Misc Activity

      Ruleset: Community

      Rule Logic Breakdown
      Alert Type: alert

      

In [32]:
test_rule = get_open_ai_response(prompt)
print(test_rule)

```
alert tcp $EXTERNAL_NET $FILE_DATA_PORTS -> $HOME_NET any (msg:"FILE-IDENTIFY Microsoft Compound File Binary v3 file magic detected"; flow:to_client,established; content:"|D0 CF 11 E0 A1 B1 1A E1|"; content:"|00 03 00|"; distance:16; within:4; metadata:ruleset community; classtype:misc-activity; sid:109; rev:1;)
```


# Generate Rule from CTI

In [54]:
generated_rules = []

inference_counter = 0
for cti in tqdm(test_ctis, "Generating Snort rules from CTIs..."):
  prompt = generate_rule_from_cti_prompt(cti)
  rule = get_open_ai_response(prompt)
  generated_rules.append(rule)
  file_name = f"quantitative_eval_snort_rule_{inference_counter}.txt"
  save_string_as_txt(generated_rule_dir_path, file_name, rule)
  inference_counter += 1

Generating Snort rules from CTIs...: 100%|██████████| 802/802 [23:06<00:00,  1.73s/it]  


In [62]:
len(generated_rules)

802

In [33]:
def load_generated_rules_from_directory(directory_path):
    """
    Reads all .txt files in the given directory and returns their contents as a list of strings.

    Args:
        directory_path (str): Path to the directory containing the txt files.

    Returns:
        List[str]: List of generated rules.
    """
    generated_rules = []

    for filename in sorted(os.listdir(directory_path)):
        if filename.endswith(".txt"):
            file_path = os.path.join(directory_path, filename)
            with open(file_path, "r", encoding="utf-8") as f:
                rule_text = f.read().strip()
                generated_rules.append(rule_text)

    return generated_rules


In [36]:
generated_rules = load_generated_rules_from_directory(generated_rule_dir_path)

# CTI-Rule Semantic Evaluation

In [37]:
def compute_dot_product_matrix_batched(model, tokenizer, test_snorts, test_ctis, batch_size=64):
    # Tokenize snorts once (since all CTIs will be compared to them)
    tokenized_snorts = tokenizer(test_snorts, return_tensors="pt", padding=True, max_length=MAX_LEN, truncation=True)
    input_ids_snorts = tokenized_snorts["input_ids"].to(DEVICE)
    attention_mask_snorts = tokenized_snorts["attention_mask"].to(DEVICE)

    with torch.no_grad():
        emb_snorts = model(input_ids_snorts, attention_mask_snorts)  # (802, dim)
        emb_snorts = emb_snorts.detach()

    # Prepare output tensor for all dot products
    num_ctis = len(test_ctis)
    dot_product_matrix = []

    for i in range(0, num_ctis, batch_size):
        batch_ctis = test_ctis[i:i + batch_size]
        tokenized_ctis = tokenizer(batch_ctis, return_tensors="pt", padding=True, max_length=MAX_LEN, truncation=True)
        input_ids_ctis = tokenized_ctis["input_ids"].to(DEVICE)
        attention_mask_ctis = tokenized_ctis["attention_mask"].to(DEVICE)

        with torch.no_grad():
            emb_ctis = model(input_ids_ctis, attention_mask_ctis)  # (B, dim)
            dot_product_batch = torch.matmul(emb_ctis, emb_snorts.T)  # (B, 802)
            dot_product_matrix.append(dot_product_batch.cpu())

        # Cleanup
        del input_ids_ctis, attention_mask_ctis, emb_ctis, dot_product_batch
        torch.cuda.empty_cache()

    # Concatenate batches into full matrix
    dot_product_matrix = torch.cat(dot_product_matrix, dim=0)  # (802, 802)
    return dot_product_matrix


In [38]:
RUN = 2
FINE_TUNED_MODEL_NAME = "all-mpnet-base-v2"
MODEL_NAME = f"/data/common/models/sentence-transformers/{FINE_TUNED_MODEL_NAME}"
FINE_TUNED_MODEL_STATE_NAME = f"contrastive_encoder_r{RUN}.pt"
MODEL_SAVE_PATH = os.path.join(project_base_path, f"script/fine_tuning/bi-encoder/snort/{FINE_TUNED_MODEL_NAME}/{FINE_TUNED_MODEL_STATE_NAME}")


# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Load model
model = SentenceEncoder(MODEL_NAME).to(DEVICE)
model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=DEVICE))

<All keys matched successfully>

In [None]:
dot_product_matrix_test = compute_dot_product_matrix_batched(
    model=model,
    tokenizer=tokenizer,
    test_snorts=generated_rules,
    test_ctis=test_ctis,
    batch_size=256
)

In [None]:
import numpy as np

def extract_sigmoid_diagonal(dot_product_matrix: torch.Tensor):
    """
    Extracts principal diagonal from a dot-product matrix,
    applies sigmoid to each value, and returns the sigmoid list.
    Also prints the mean of sigmoid values.

    Args:
        dot_product_matrix (torch.Tensor): A square matrix of shape [N x N]

    Returns:
        List[float]: Sigmoid values of diagonal entries
    """
    assert dot_product_matrix.shape[0] == dot_product_matrix.shape[1], "Matrix must be square."

    # Step 1: Extract diagonal
    diag_values = dot_product_matrix.diag()  # shape: (N,)

    # Step 2: Apply sigmoid
    sigmoid_values = torch.sigmoid(diag_values)

    # Step 3: Convert to list and compute mean
    sigmoid_list = sigmoid_values.tolist()
    mean_value = torch.mean(sigmoid_values).item()

    # Output
    print(f"Mean of sigmoid(diagonal values): {mean_value:.4f}")
    return sigmoid_list

In [65]:
sigmoid_diagonal_scores = extract_sigmoid_diagonal(dot_product_matrix_test)

Mean of sigmoid(diagonal values): 0.7217


# Ragas 

In [68]:
open_ai_key = "OPENAI_KEY"
os.environ['OPENAI_API_KEY'] = open_ai_key

In [None]:
import numpy as np
from langchain.embeddings import OpenAIEmbeddings

def compute_dot_product_matrix_openai(test_snorts, test_ctis, batch_size=50):
    embedder = OpenAIEmbeddings()  # Uses text-embedding-ada-002 by default

    # Step 1: Get embeddings for all Snort rules
    snort_embeddings = []
    for i in range(0, len(test_snorts), batch_size):
        batch = test_snorts[i:i + batch_size]
        snort_embeddings.extend(embedder.embed_documents(batch))  # List of vectors

    snort_embeddings = np.array(snort_embeddings)  # Shape: (N, D)
    snort_embeddings_norm = np.linalg.norm(snort_embeddings, axis=1, keepdims=True)

    # Step 2: Compute batched dot products with CTIs
    dot_product_matrix = []

    for i in range(0, len(test_ctis), batch_size):
        batch = test_ctis[i:i + batch_size]
        cti_embeddings = embedder.embed_documents(batch)
        cti_embeddings = np.array(cti_embeddings)
        cti_embeddings_norm = np.linalg.norm(cti_embeddings, axis=1, keepdims=True)

        # Normalize and compute dot product
        sim_matrix = np.dot(cti_embeddings, snort_embeddings.T) / (
            cti_embeddings_norm @ snort_embeddings_norm.T
        )
        dot_product_matrix.append(sim_matrix)

    dot_product_matrix = np.vstack(dot_product_matrix)  # Final shape: (len(test_ctis), len(test_snorts))
    return dot_product_matrix

In [71]:
def extract_diagonal(dot_product_matrix: torch.Tensor):
    """
    Extracts principal diagonal from a dot-product matrix,
    applies sigmoid to each value, and returns the sigmoid list.
    Also prints the mean of sigmoid values.

    Args:
        dot_product_matrix (torch.Tensor): A square matrix of shape [N x N]

    Returns:
        List[float]: Sigmoid values of diagonal entries
    """
    assert dot_product_matrix.shape[0] == dot_product_matrix.shape[1], "Matrix must be square."

    # Step 1: Extract diagonal
    diag_values = dot_product_matrix.diag()  # shape: (N,)

    # Step 3: Convert to list and compute mean
    diag_list = diag_values.tolist()
    mean_value = torch.mean(diag_values).item()

    # Output
    print(f"Mean of sigmoid(diagonal values): {mean_value:.4f}")
    return diag_list

In [72]:
dot_product_matrix_test = compute_dot_product_matrix_openai(
    test_snorts=generated_rules,
    test_ctis=test_ctis,
    batch_size=50
)

In [73]:
diagonal_scores_openai = extract_diagonal(torch.tensor(dot_product_matrix_test))

Mean of sigmoid(diagonal values): 0.8648


# Bert

In [76]:
from bert_score import score

def compute_bert_scores(cti_list, generated_rules, lang="en", model_type="bert-base-uncased"):
    """
    Computes BERTScore between CTI descriptions and generated Snort rules.

    Args:
        cti_list (List[str]): List of CTI strings.
        generated_rules (List[str]): Corresponding generated Snort rule strings.
        lang (str): Language (default = "en").
        model_type (str): BERT model to use (default = DeBERTa-MNLI).

    Returns:
        dict: Precision, Recall, F1 scores (averaged) and all individual F1s.
    """
    assert len(cti_list) == len(generated_rules), "Mismatch in CTI and rule count."

    P, R, F1 = score(generated_rules, cti_list, lang=lang, model_type=model_type, verbose=True)

    mean_p = P.mean().item()
    mean_r = R.mean().item()
    mean_f1 = F1.mean().item()

    print(f"\nBERTScore Results:\nPrecision: {mean_p:.4f} | Recall: {mean_r:.4f} | F1: {mean_f1:.4f}")
    return {
        "precision": mean_p,
        "recall": mean_r,
        "f1": mean_f1,
        "f1_scores": F1.tolist()  # optional: return individual F1s
    }


In [77]:
bert_score_results = compute_bert_scores(test_ctis, generated_rules)

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

calculating scores...
computing bert embedding.


  0%|          | 0/25 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/13 [00:00<?, ?it/s]

done in 6.88 seconds, 116.57 sentences/sec

BERTScore Results:
Precision: 0.6868 | Recall: 0.5501 | F1: 0.6106
