# Imports

In [1]:
import os
import gc
import pickle
import torch
from torch import nn
from transformers import AutoTokenizer, AutoModel
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import statistics

# Path Declaration

In [2]:
project_base_path = os.path.dirname(os.path.dirname(os.path.dirname(os.getcwd())))
project_base_path

'/home/ANONYMOUS/projects/FALCON'

In [3]:
cti_snort_eval_data_path = os.path.join(project_base_path, "data/evaluation/cti-rule/snort/cti_snort_eval_data.pkl")
cti_snort_eval_data_path

'/home/ANONYMOUS/projects/FALCON/data/evaluation/cti-rule/snort/cti_snort_eval_data.pkl'

# Misc Functions

In [4]:
def load_from_pickle(file_path):
    """
    Loads data from a pickle file.

    :param file_path: Path to the pickle file
    :return: Loaded data
    """
    try:
        with open(file_path, 'rb') as file:
            return pickle.load(file)
    except Exception as e:
        print(f"Error loading data from pickle: {e}")
        return None

In [5]:
def map_subset_indices(full_list, subset_list):
    """
    Maps each string in the subset list to its index in the full list.

    Args:
        full_list (list of str): The complete list of strings.
        subset_list (list of str): A subset of strings present in the full list.

    Returns:
        dict: A dictionary with subset strings as keys and their indices in the full list as values.
    """
    index_map = {}
    for item in subset_list:
        try:
            index_map[item] = full_list.index(item)
        except ValueError:
            # Just in case the subset contains a string not found in full_list
            index_map[item] = -1
    return index_map

In [6]:
def evaluate_topk_match(gt_indices, sorted_pred_indices, top_k):
    top_k_preds = set(sorted_pred_indices[:top_k])
    matched = top_k_preds.intersection(set(gt_indices))
    return 100 * len(matched) / len(gt_indices) if gt_indices else 0

In [7]:
def reciprocal_rank(gt_indices, sorted_pred_indices):
    for rank, idx in enumerate(sorted_pred_indices, start=1):
        if idx in gt_indices:
            return 1.0 / rank
    return 0.0

In [8]:
def average_precision(gt_indices, sorted_pred_indices):
    hits, score = 0, 0.0
    for rank, idx in enumerate(sorted_pred_indices, start=1):
        if idx in gt_indices:
            hits += 1
            score += hits / rank
    return score / len(gt_indices) if gt_indices else 0.0


# Environment Setup

In [9]:
# ⚙️ Config
MODEL_NAME = "/data/common/models/sentence-transformers/all-MiniLM-L6-v2"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# RUN = 2
###########################
MAX_LEN = 512
FINE_TUNED_MODEL_NAME = "all-MiniLM-L6-v2"
# FINE_TUNED_MODEL_STATE_NAME = f"contrastive_encoder_r{RUN}.pt"
SEED = 42
torch.manual_seed(SEED)
if torch.cuda.is_available():  
    torch.cuda.manual_seed_all(SEED)
        
# MODEL_LOAD_PATH = os.path.join(project_base_path, f"script/fine_tuning/bi-encoder/snort/{FINE_TUNED_MODEL_NAME}/{FINE_TUNED_MODEL_STATE_NAME}")

In [10]:
# Bi-Encoder Model
class SentenceEncoder(nn.Module):
    def __init__(self, model_name):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)

    def forward(self, input_ids, attention_mask):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        embeddings = outputs.last_hidden_state[:, 0]  # CLS token
        return nn.functional.normalize(embeddings, p=2, dim=1)  # Normalize for cosine similarity

In [11]:
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Load model
model = SentenceEncoder(MODEL_NAME).to(DEVICE)
# model.load_state_dict(torch.load(MODEL_LOAD_PATH, map_location=DEVICE))

# Data Preparation

### Load Data

In [12]:
# Load the data back from the pickle file
cti_snort_eval_data = load_from_pickle(cti_snort_eval_data_path)
print(len(cti_snort_eval_data.keys()))

802


### Pre-processing

In [13]:
consolidated_dummy_snort_rules = []
for cti, rules in cti_snort_eval_data.items():
    consolidated_dummy_snort_rules.extend(rules)

In [14]:
len(consolidated_dummy_snort_rules)

3063

# Evaluation

In [15]:
test_ctis = list(cti_snort_eval_data.keys())

total_recall, total_map = 0, 0
recall_k_list = []
map_score_list = []

for cti in tqdm(test_ctis, "Evaluating CTI-Snort Semantic Scorer"):
    result_idx = map_subset_indices(consolidated_dummy_snort_rules, cti_snort_eval_data[cti])
    gt_indices = list(result_idx.values())
    
    with torch.no_grad():
        tokenized_cti = tokenizer(cti, return_tensors="pt", padding=True, max_length=MAX_LEN, truncation=True)
        tokenized_dummy_rules = tokenizer(consolidated_dummy_snort_rules, return_tensors="pt", padding=True, max_length=MAX_LEN, truncation=True)
        
        input_ids_tokenized_cti = tokenized_cti['input_ids'].to(DEVICE)
        attention_mask_tokenized_cti = tokenized_cti['attention_mask'].to(DEVICE)
        input_ids_tokenized_dummy_rules = tokenized_dummy_rules['input_ids'].to(DEVICE)
        attention_mask_tokenized_dummy_rules = tokenized_dummy_rules['attention_mask'].to(DEVICE)
        
        emb_tokenized_cti = model(input_ids_tokenized_cti, attention_mask_tokenized_cti)
        emb_tokenized_dummy_rules = model(input_ids_tokenized_dummy_rules, attention_mask_tokenized_dummy_rules)
       
        dot_product_matrix = torch.matmul(emb_tokenized_cti, emb_tokenized_dummy_rules.T)
        
        similarity_scores = dot_product_matrix[0]
        sorted_indices = torch.argsort(similarity_scores, descending=True).tolist()
        
        # print(sorted_indices)
        # print(gt_indices)
        
        # Metric Calculations
        k = len(gt_indices)
        recall_k = evaluate_topk_match(gt_indices, sorted_indices, 10)
        recall_k_list.append(recall_k)
        map_score = average_precision(gt_indices, sorted_indices)
        map_score_list.append(map_score)

        # Accumulate
        total_recall += recall_k
        total_map += map_score

        # Print for each CTI if needed
        # print(f"\nCTI: {cti[:50]}...")
        # print(f"GT Indices: {gt_indices}")
        # print(f"Top-{k} Recall: {recall_k:.2f}%, MAP: {map_score:.4f}")
    
    # Free up memory
    del emb_tokenized_cti, emb_tokenized_dummy_rules
    del input_ids_tokenized_cti, attention_mask_tokenized_cti
    del input_ids_tokenized_dummy_rules, attention_mask_tokenized_dummy_rules
    torch.cuda.empty_cache()   # Clear GPU memory 

Evaluating CTI-Snort Semantic Scorer: 100%|██████████| 802/802 [10:29<00:00,  1.27it/s]


### Top - k

In [16]:
# Final Aggregated Scores
n = len(test_ctis)
print("\n=== Overall Evaluation Results ===")
print(f"Average Recall@K: {total_recall / n:.2f}%")
print(f"Mean Average Precision (MAP): {total_map / n:.4f}")
    
# Standard Deviation of Scores
recall_std = statistics.stdev(recall_k_list)
map_std = statistics.stdev(map_score_list)
print(f"Recall@K Standard Deviation: {recall_std:.4f}")
print(f"MAP Standard Deviation: {map_std:.4f}")  


=== Overall Evaluation Results ===
Average Recall@K: 18.04%
Mean Average Precision (MAP): 0.2047
Recall@K Standard Deviation: 25.1929
MAP Standard Deviation: 0.2528


### Top - 10

In [16]:
# Final Aggregated Scores
n = len(test_ctis)
print("\n=== Overall Evaluation Results ===")
print(f"Average Recall@K: {total_recall / n:.2f}%")
print(f"Mean Average Precision (MAP): {total_map / n:.4f}")
    
# Standard Deviation of Scores
recall_std = statistics.stdev(recall_k_list)
map_std = statistics.stdev(map_score_list)
print(f"Recall@K Standard Deviation: {recall_std:.4f}")
print(f"MAP Standard Deviation: {map_std:.4f}")  


=== Overall Evaluation Results ===
Average Recall@K: 27.80%
Mean Average Precision (MAP): 0.2047
Recall@K Standard Deviation: 31.6807
MAP Standard Deviation: 0.2528


### Top - 20

In [16]:
# Final Aggregated Scores
n = len(test_ctis)
print("\n=== Overall Evaluation Results ===")
print(f"Average Recall@K: {total_recall / n:.2f}%")
print(f"Mean Average Precision (MAP): {total_map / n:.4f}")
    
# Standard Deviation of Scores
recall_std = statistics.stdev(recall_k_list)
map_std = statistics.stdev(map_score_list)
print(f"Recall@K Standard Deviation: {recall_std:.4f}")
print(f"MAP Standard Deviation: {map_std:.4f}")  


=== Overall Evaluation Results ===
Average Recall@K: 35.40%
Mean Average Precision (MAP): 0.2047
Recall@K Standard Deviation: 34.2419
MAP Standard Deviation: 0.2528
