In [1]:
from __future__ import annotations

import json, hashlib, random, time, sys, platform
from dataclasses import dataclass, asdict, replace
from pathlib import Path
from typing import Any, Dict, List, Tuple

import numpy as np
import pandas as pd
import networkx as nx
import heapq
import hashlib

from scipy import sparse
from sklearn.pipeline import Pipeline

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm

from sklearn.compose import ColumnTransformer

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.decomposition import TruncatedSVD
from sklearn.preprocessing import Normalizer
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.ensemble import IsolationForest

import matplotlib.pyplot as plt

import copy

from dataclasses import replace

from dataclasses import replace
from collections import deque
import pandas as pd
import networkx as nx

# Baseline: TF-IDF + IsolationForest anomaly scores, temporal event graph, and RCA walk from the alert

In [2]:
@dataclass(frozen=True)
class BaselineConfig:
    """Frozen hyper-parameters for the TF-IDF + IsolationForest baseline.

    Every setting is an annotated dataclass field, so ``asdict(cfg_baseline)``
    captures the whole configuration (it is hashed into the run fingerprint
    and saved as JSON). Field order defines the positional constructor order;
    keep it stable.
    """

    # ---- data / split ----------------------------------------------------
    golden_id: str = "episode_016"                        # held-out episode
    dataset_path: str = "episodes_all_baseline.parquet"   # relative to <project>/data
    alert_csv_path: str = "Golden/Alert_1.csv"            # relative to <project>

    # ---- feature extraction (TF-IDF on text_col, optional one-hot cat_cols)
    max_text_features: int = 5_000
    min_df: int = 5
    ngram_range: tuple[int, int] = (1, 2)
    text_col: str = "masked_message_cl"
    cat_cols: tuple[str, ...] = ()

    # ---- IsolationForest -------------------------------------------------
    if_n_estimators: int = 100
    if_contamination: str = "auto"
    if_random_state: int = 42

    # ---- temporal graph: max seconds between consecutive linked events ----
    graph_max_gap_host: int = 60
    graph_max_gap_actor: int = 2 * 60
    use_host_edges: bool = True
    use_actor_edges: bool = True

    # ---- alert attachment (NOTE(review): attach_alert_node hard-codes its
    # strategy and does not read this field) ----------------------------
    attach_strategy: str = "prefer_sshd_success_then_closest"

    # ---- RCA walk ---------------------------------------------------------
    rca_max_nodes: int = 200
    rca_max_hops: int | None = None
    rca_max_back_seconds: int = 1_800   # 30 minutes before the alert
    rca_forward_slack_seconds: int = 60
    rca_priority_mode: str = "baseline"

    # ---- evaluation @k ---------------------------------------------------
    ks: tuple[int, ...] = (5, 10, 20, 50)


cfg_baseline = BaselineConfig()

In [None]:
# Load the episodes table, give every row a stable node_id (its row position),
# and set up a run directory keyed by config hash + dataset hash.
# NOTE(review): this cell was never executed (In [None]) and cannot run on a
# fresh kernel at this position: PROJECT_DIR is defined in the next cell,
# RUNS_DIR and save_json only further down, and file_sha256_12 / get_env_info
# are not defined anywhere in the visible notebook. TODO move it after those
# definitions (and define the two helpers) or delete it.
dataset_path = PROJECT_DIR / "data" / cfg_baseline.dataset_path
assert dataset_path.exists(), f"Missing dataset parquet: {dataset_path}"

episodes_df = pd.read_parquet(dataset_path)
episodes_df = episodes_df.reset_index(drop=True)
episodes_df["node_id"] = episodes_df.index.astype(int)

dataset_fingerprint = file_sha256_12(dataset_path)

# First 12 hex chars of SHA-256 over the sorted-key JSON of the config, so an
# identical config + dataset always maps to the same run directory.
cfg_fingerprint = hashlib.sha256(json.dumps(asdict(cfg_baseline), sort_keys=True).encode("utf-8")).hexdigest()[:12]
RUN_DIR = RUNS_DIR / f"run_{cfg_fingerprint}_{dataset_fingerprint}"
RUN_DIR.mkdir(parents=True, exist_ok=True)

# Provenance: config, environment and dataset fingerprint next to the run outputs.
save_json(asdict(cfg_baseline), RUN_DIR / "baseline_config.json")
save_json(get_env_info(), RUN_DIR / "environment.json")
save_json({"dataset_path": str(dataset_path), "dataset_sha256_12": dataset_fingerprint}, RUN_DIR / "dataset_fingerprint.json")

print("Loaded:", dataset_path)
print("Rows:", len(episodes_df))
print("Run dir:", RUN_DIR)

In [5]:
# Resolve the project root (assumed to be the parent of the notebook's working
# directory — TODO confirm) and load the episodes table.
# NOTE(review): repeats the loading done by the previous cell.
NOTEBOOK_DIR = Path.cwd()
PROJECT_DIR = NOTEBOOK_DIR.parent
dataset_path = PROJECT_DIR / "data" / cfg_baseline.dataset_path
assert dataset_path.exists(), f"Missing dataset parquet: {dataset_path}"

episodes_df = pd.read_parquet(dataset_path)
episodes_df = episodes_df.reset_index(drop=True)
# node_id = row position after reset_index; it is the graph node key downstream.
episodes_df["node_id"] = episodes_df.index.astype(int)

In [6]:
@dataclass
class BaselineFeatureBuilder:
    """Feature pipeline for the baseline.

    It applies TF-IDF to ``cfg_baseline.text_col`` and, optionally, a one-hot
    encoding of ``cfg_baseline.cat_cols``. All other columns are dropped (the
    ColumnTransformer default ``remainder="drop"``).

    Fit on training rows with :meth:`fit_transform`, then apply the same
    vocabulary/encoding to held-out rows with :meth:`transform`. Both return the
    ColumnTransformer output. It is presumably a scipy sparse matrix, but the
    declared ``csr_matrix`` return type is not enforced.
    """

    cfg_baseline: BaselineConfig

    def __post_init__(self):
        cfg = self.cfg_baseline
        self.text_vectorizer = TfidfVectorizer(
            max_features=cfg.max_text_features,
            min_df=cfg.min_df,
            ngram_range=cfg.ngram_range,
        )

        # The text column is given as a scalar name so TfidfVectorizer receives a
        # 1-D column of strings (a list of names would hand it a 2-D frame).
        transformers = [("text", self.text_vectorizer, cfg.text_col)]

        cat_cols = list(getattr(cfg, "cat_cols", None) or ())
        if cat_cols:
            # Imported here because OneHotEncoder is not imported at the top of
            # the notebook; without this any non-empty cat_cols raised NameError.
            from sklearn.preprocessing import OneHotEncoder

            try:
                # scikit-learn >= 1.2 names the option ``sparse_output``.
                self.cat_encoder = OneHotEncoder(handle_unknown="ignore", sparse_output=True)
            except TypeError:
                # Older scikit-learn only accepts ``sparse``.
                self.cat_encoder = OneHotEncoder(handle_unknown="ignore", sparse=True)
            transformers.append(("cat", self.cat_encoder, cat_cols))

        self.preprocessor = ColumnTransformer(transformers=transformers)
        self.pipeline = Pipeline(steps=[("preprocess", self.preprocessor)])

    def fit_transform(self, df: pd.DataFrame) -> sparse.csr_matrix:
        """Fit the pipeline on ``df`` and return its feature matrix."""
        return self.pipeline.fit_transform(df)

    def transform(self, df: pd.DataFrame) -> sparse.csr_matrix:
        """Transform ``df`` with the already-fitted pipeline."""
        return self.pipeline.transform(df)

# Hold out the golden episode. Every other episode is training data for the
# unsupervised feature pipeline and IsolationForest.
train_df = episodes_df[episodes_df["episode_id"] != cfg_baseline.golden_id].copy()
golden_df = episodes_df[episodes_df["episode_id"] == cfg_baseline.golden_id].copy()

# TF-IDF needs strings: missing text becomes "".
train_df[cfg_baseline.text_col] = train_df[cfg_baseline.text_col].fillna("").astype(str)
golden_df[cfg_baseline.text_col] = golden_df[cfg_baseline.text_col].fillna("").astype(str)

assert len(golden_df) > 0, f"No rows for golden_id={cfg_baseline.golden_id}"

# Fit the vocabulary on training rows only, then reuse it for the golden episode.
feat_builder = BaselineFeatureBuilder(cfg_baseline=cfg_baseline)
X_train = feat_builder.fit_transform(train_df)
X_gold  = feat_builder.transform(golden_df)

iso = IsolationForest(
    n_estimators=cfg_baseline.if_n_estimators,
    contamination=cfg_baseline.if_contamination,
    random_state=cfg_baseline.if_random_state,
    n_jobs=-1,
)
iso.fit(X_train)

raw_scores = iso.decision_function(X_gold)   # higher = more normal
golden_df["baseline_score_iso"] = -raw_scores  # higher = more suspicious

# golden_df keeps episodes_df's index labels, so this writes back row-for-row.
episodes_df.loc[golden_df.index, "baseline_score_iso"] = golden_df["baseline_score_iso"].values
print(golden_df["baseline_score_iso"].describe())

count    29998.000000
mean        -0.117049
std          0.034239
min         -0.154611
25%         -0.139395
50%         -0.128928
75%         -0.101724
max          0.109014
Name: baseline_score_iso, dtype: float64


In [7]:
def add_normalized_score(df: pd.DataFrame, col: str = "baseline_score_iso", out_col: str = "score_norm") -> pd.DataFrame:
    """Return a copy of ``df`` with ``df[col]`` min-max scaled into ``out_col``.

    Non-missing values are mapped linearly onto [0, 1] (min -> 0, max -> 1).
    Missing values always become 0.0 ("least suspicious"). Degenerate cases:

    * every value missing: the whole output column is 0.0;
    * all non-missing values equal: they become 0.5, while missing rows stay 0.0.

    The input frame is not modified.
    """
    out = df.copy()
    raw = out[col]
    present = raw.notna()
    if not present.any():
        out[out_col] = 0.0
        return out

    lo, hi = float(raw[present].min()), float(raw[present].max())
    if hi == lo:
        # No spread to scale. Missing rows still get 0.0 so the "missing means
        # least suspicious" rule holds in every branch.
        scaled = pd.Series(0.5, index=out.index).where(present, 0.0)
    else:
        scaled = ((raw - lo) / (hi - lo)).where(present, 0.0)
    out[out_col] = scaled
    return out

# Min-max scale the golden IsolationForest scores into [0, 1] (RCA walk priority).
golden_df = add_normalized_score(golden_df, col="baseline_score_iso", out_col="score_norm")

In [8]:
# Mirror the normalized golden scores back into the full frame (aligned by index label).
episodes_df.loc[golden_df.index, "score_norm"] = golden_df["score_norm"].values

In [9]:
def build_episode_graph(df: pd.DataFrame, cfg_baseline: BaselineConfig) -> nx.DiGraph:
    """Build a temporal event graph for one episode.

    Each row becomes a node keyed by ``int(node_id)``. The node carries
    timestamp, stream, masked text, actor/host and score attributes.
    Directed edges link consecutive events, in timestamp order, that share a key:

    * ``actor_ip_temporal`` for the same ``actor_ip_anon``, when the gap is at
      most ``cfg_baseline.graph_max_gap_actor`` seconds;
    * ``host_temporal`` for the same ``host_anon``, when the gap is at most
      ``cfg_baseline.graph_max_gap_host`` seconds.

    Keys that stringify to "none", "", "nan" or "None" are never linked. Host
    edges are added after actor edges, so when both connect the same ordered
    pair, the host edge's attributes overwrite the actor edge's.
    """
    events = df.sort_values("timestamp").copy()
    graph = nx.DiGraph()
    missing_keys = ("none", "", "nan", "None")

    for _, rec in events.iterrows():
        nid = int(rec["node_id"])
        attrs = {
            "timestamp": rec["timestamp"],
            "stream": rec["stream"],
            "masked_message_cl": rec["masked_message_cl"],
            "actor_ip_anon": rec.get("actor_ip_anon", "none"),
            "host_anon": rec.get("host_anon", "none"),
            "baseline_score_iso": float(rec.get("baseline_score_iso", np.nan)),
            "score_norm": float(rec.get("score_norm", 0.0)),
        }
        graph.add_node(nid, **attrs)

    def link_consecutive(key_col: str, max_gap: int, edge_kind: str) -> None:
        # Walk each key's events in time order; a pair whose gap exceeds
        # max_gap is simply not linked, and the walk continues from the later event.
        for key, bucket in events.groupby(key_col):
            if str(key) in missing_keys:
                continue
            prev = None  # (node_id, timestamp) of the previous event for this key
            for _, rec in bucket.sort_values("timestamp").iterrows():
                cur = (int(rec["node_id"]), rec["timestamp"])
                if prev is not None:
                    gap = (cur[1] - prev[1]).total_seconds()
                    if 0 <= gap <= max_gap:
                        graph.add_edge(prev[0], cur[0], kind=edge_kind, dt=float(gap))
                prev = cur

    if cfg_baseline.use_actor_edges:
        link_consecutive("actor_ip_anon", cfg_baseline.graph_max_gap_actor, "actor_ip_temporal")
    if cfg_baseline.use_host_edges:
        link_consecutive("host_anon", cfg_baseline.graph_max_gap_host, "host_temporal")

    return graph

# Baseline temporal graph for the golden episode (node score_norm comes from golden_df).
G_golden = build_episode_graph(golden_df, cfg_baseline)
print("Golden graph:", G_golden.number_of_nodes(), "nodes,", G_golden.number_of_edges(), "edges")

Golden graph: 29998 nodes, 33485 edges


In [11]:
def save_json(obj: Any, path: Path) -> None:
    """Write ``obj`` to ``path`` as pretty-printed (2-space indent) UTF-8 JSON.

    Missing parent directories are created. Non-ASCII characters are written
    verbatim. Any value ``json`` cannot encode natively (paths, timestamps,
    numpy scalars, ...) is serialized via ``str()``.
    """
    path.parent.mkdir(exist_ok=True, parents=True)
    with open(path, mode="w", encoding="utf-8") as fh:
        json.dump(obj, fh, ensure_ascii=False, indent=2, default=str)

In [13]:
# Node id of the synthetic alert node added by attach_alert_node.
# NOTE(review): rebound to "ALERT_NODE" in the contrastive section. Functions
# that use it as a default argument capture whatever value is current when
# their `def` runs.
ALERT_NODE_ID = "ALERT_NODE_ID"

def short_hash(x):
    """Return the first 12 hex characters of SHA-256 over ``str(x)``.

    Missing values (per ``pd.isna``) and blank or whitespace-only strings map
    to the sentinel ``"none"`` instead. The result is compared with the dataset's
    ``actor_ip_anon`` column. Presumably that column was built with the same
    scheme; TODO confirm.
    """
    if pd.isna(x):
        return "none"
    text = str(x)
    if not text.strip():
        return "none"
    digest = hashlib.sha256(text.encode()).hexdigest()
    return digest[:12]

# Load the golden alert (first row of the alert CSV) and pull out the fields
# used to attach it to the log graph.
alert_path = PROJECT_DIR / cfg_baseline.alert_csv_path
assert alert_path.exists(), f"Missing alert csv: {alert_path}"
alert_df = pd.read_csv(alert_path)
alert_row = alert_df.iloc[0]

alert_ts = pd.to_datetime(alert_row["@timestamp"], utc=True)  # tz-aware UTC
alert_src_ip = alert_row.get("source.ip", None)
# NOTE(review): empty CSV cells come back as NaN and end up as "nan" in the
# alert node's text built by attach_alert_node.
alert_name   = alert_row.get("kibana.alert.rule.name", "")
alert_reason = alert_row.get("kibana.alert.reason", "")
alert_desc   = alert_row.get("description", "")

def attach_alert_node(G: nx.DiGraph, df: pd.DataFrame, alert_ts: pd.Timestamp, alert_src_ip=None,
                      alert_name: str = "", alert_reason: str = "", alert_desc: str = "") -> list[int]:
    """Add the synthetic alert node to ``G`` and link it to one log event.

    The alert node uses the id ``ALERT_NODE_ID``, read at call time. It gets
    stream ``"alert"``, a text summary, ``baseline_score_iso=0.0`` and
    ``score_norm=1.0``. One ``alert_to_log`` edge connects it to the event
    chosen by the first rule that finds a candidate:

    1. an sshd success (``system.auth`` stream, text containing ``out=success``
       and ``proc=sshd``) whose ``actor_ip_anon`` equals ``short_hash`` of the
       alert source IP, closest in time to the alert;
    2. otherwise any sshd success, closest in time;
    3. otherwise the event closest in time.

    Args:
        G: graph whose nodes include every ``df["node_id"]``; modified in place.
        df: candidate events. It needs ``timestamp``, ``stream``,
            ``masked_message_cl`` and ``node_id``; ``actor_ip_anon`` is optional.
        alert_ts: alert time. It must be comparable with ``df["timestamp"]``
            (same tz-awareness; NOTE(review): the callers parse it as UTC).
        alert_src_ip: raw source IP. Blank, "nan" and "none" are treated as absent.
        alert_name, alert_reason, alert_desc: copied into the alert node's text.

    Returns:
        A one-element list with the attached ``node_id``. The edge's ``dt`` is
        ``event_ts - alert_ts`` in seconds.

    Raises:
        ValueError: if ``df`` is empty or no event has a usable timestamp. In
            that case ``G`` is left unchanged.
    """
    if df.empty:
        raise ValueError("attach_alert_node: df has no events to attach the alert to")

    attacker_ip_str = None
    if alert_src_ip is not None:
        attacker_ip_str = str(alert_src_ip).strip()
        if not attacker_ip_str or attacker_ip_str.lower() in ("nan", "none"):
            attacker_ip_str = None

    df_sorted = df.sort_values("timestamp").copy()
    # Absolute distance to the alert in seconds. On ties, nsmallest keeps the
    # first row in frame order, i.e. the earlier event.
    df_sorted["dt_abs"] = (df_sorted["timestamp"] - alert_ts).abs().dt.total_seconds()

    # Shared by rules 1 and 2, so it is computed once.
    messages = df_sorted["masked_message_cl"].astype(str)
    is_sshd_success = (
        (df_sorted["stream"] == "system.auth")
        & messages.str.contains("out=success", na=False)
        & messages.str.contains("proc=sshd", na=False)
    )

    best = pd.DataFrame()

    # 1) sshd success from the attacker IP (needs the actor column)
    if attacker_ip_str is not None and "actor_ip_anon" in df_sorted.columns:
        from_attacker = df_sorted["actor_ip_anon"] == short_hash(attacker_ip_str)
        cand1 = df_sorted[is_sshd_success & from_attacker]
        if not cand1.empty:
            best = cand1.nsmallest(1, "dt_abs")

    # 2) else any sshd success
    if best.empty:
        cand2 = df_sorted[is_sshd_success]
        if not cand2.empty:
            best = cand2.nsmallest(1, "dt_abs")

    # 3) else the closest event of any kind
    if best.empty:
        best = df_sorted.nsmallest(1, "dt_abs")

    # nsmallest skips NaN distances, so this only happens when every timestamp is missing.
    if best.empty:
        raise ValueError("attach_alert_node: no event has a usable timestamp to attach the alert to")

    # The graph is mutated only after a target has been found.
    G.add_node(
        ALERT_NODE_ID,
        timestamp=alert_ts,
        stream="alert",
        masked_message_cl=f"[alert] name={alert_name} reason={alert_reason} desc={alert_desc}",
        baseline_score_iso=0.0,
        score_norm=1.0,
    )

    attached = [int(best.iloc[0]["node_id"])]
    for nid in attached:
        dt = (G.nodes[nid]["timestamp"] - alert_ts).total_seconds()
        G.add_edge(ALERT_NODE_ID, nid, kind="alert_to_log", dt=float(dt))

    return attached

# Attach the alert to the baseline golden graph (adds the alert node and one edge to G_golden).
attached_log_ids = attach_alert_node(G_golden, golden_df, alert_ts, alert_src_ip, alert_name, alert_reason, alert_desc)
print("Attached node_ids:", attached_log_ids)

Attached node_ids: [560913]


In [14]:
def rca_walk(G: nx.DiGraph, cfg_baseline: BaselineConfig, alert_node_id: str = ALERT_NODE_ID) -> tuple[nx.DiGraph, set, list]:
    """Best-first backward walk from the alert node that collects a root-cause subgraph.

    The frontier starts as the alert node's successors (the attached log
    events). The walk then repeatedly selects the frontier node with the highest
    ``score_norm``. Ties are broken by smaller ``|alert_ts - ts|``, then by
    fewer hops, then by node id. The selected node's ``G.predecessors`` are
    pushed onto the frontier.

    A node is considered only if
    ``-rca_forward_slack_seconds <= alert_ts - ts <= rca_max_back_seconds``.
    Nodes beyond ``rca_max_hops`` (when set) are skipped. The walk stops after
    ``rca_max_nodes`` nodes (alert included) or when the frontier is empty.

    Returns:
        (copy of the subgraph induced by the selected nodes, selected node set,
         nodes in selection order starting with the alert node)
    """
    import math

    alert_ts = G.nodes[alert_node_id]["timestamp"]

    def get_score_norm(node_id):
        # Non-numeric and NaN scores count as 0.0. A NaN inside a heap key makes
        # every comparison False and silently breaks the best-first order.
        if node_id == alert_node_id:
            return 1.0
        try:
            score = float(G.nodes[node_id].get("score_norm", 0.0))
        except (TypeError, ValueError):
            return 0.0
        return 0.0 if math.isnan(score) else score

    def in_window(dt: float) -> bool:
        # False for NaN dt (e.g. a NaT timestamp), so such nodes are never queued.
        return -cfg_baseline.rca_forward_slack_seconds <= dt <= cfg_baseline.rca_max_back_seconds

    selected = {alert_node_id}
    visited = {alert_node_id}
    ranked = [alert_node_id]

    # Heap entries: ((-score, |dt|), hops, node_id); heapq pops the smallest.
    heap = []
    for succ in G.successors(alert_node_id):
        dt = (alert_ts - G.nodes[succ]["timestamp"]).total_seconds()
        if in_window(dt):
            heapq.heappush(heap, ((-get_score_norm(succ), abs(dt)), 1, succ))

    while heap and len(selected) < cfg_baseline.rca_max_nodes:
        _, hops, node_id = heapq.heappop(heap)
        if node_id in visited:
            continue  # stale duplicate entry
        if cfg_baseline.rca_max_hops is not None and hops > cfg_baseline.rca_max_hops:
            continue

        visited.add(node_id)
        selected.add(node_id)
        ranked.append(node_id)

        for pred in G.predecessors(node_id):
            if pred in visited or pred == alert_node_id:
                continue
            dt = (alert_ts - G.nodes[pred]["timestamp"]).total_seconds()
            if in_window(dt):
                heapq.heappush(heap, ((-get_score_norm(pred), abs(dt)), hops + 1, pred))

    subG = G.subgraph(selected).copy()
    return subG, selected, ranked

# Run the baseline RCA walk from the alert node and report the subgraph size.
subG_baseline, selected_nodes_baseline, rca_ranked_baseline = rca_walk(G_golden, cfg_baseline, alert_node_id=ALERT_NODE_ID)
print("RCA subgraph:")
print("  Nodes:", subG_baseline.number_of_nodes())
print("  Edges:", subG_baseline.number_of_edges())

RCA subgraph:
  Nodes: 200
  Edges: 227


# Contrastive: InfoNCE-trained projector over TF-IDF/SVD embeddings, IsolationForest scoring, and RCA walk

In [15]:
@dataclass(frozen=True)
class ContrastiveRunConfig:
    """Frozen settings for the contrastive-projection RCA run.

    Every setting is an annotated dataclass field. It is therefore included in
    ``asdict(cfg)``, and so in the saved ``config.json``, and it can be
    overridden at construction, e.g. ``ContrastiveRunConfig(svd_dim=128)``.
    """

    # Data
    golden_id: str = "episode_016"
    dataset_path: str = "episodes_all_baseline.parquet"   # relative to <project>/data
    text_col: str = "masked_message_cl"
    alert_csv_path: str = "Golden/Alert_1.csv"            # relative to <project>

    # Pair mining
    pair_window_sec: int = 90
    max_pos_per_anchor: int = 2
    max_anchors_per_episode: int = 10000
    max_neg_per_anchor: int = 2   # NOTE(review): not passed to mine_positive_pairs in the visible cells

    # Feature base. svd_dim, svd_dim_list and normalize_svd are declared at the
    # end of the class so the positional order of these fields is unchanged.
    tfidf_max_features: int = 5000
    tfidf_min_df: int = 5
    tfidf_ngram_range: tuple = (1, 2)

    # Graph (max seconds between consecutive events joined by an edge)
    graph_max_gap_actor: int = 2 * 60
    graph_max_gap_host: int = 60
    use_host_edges: bool = True
    use_actor_edges: bool = True

    # RCA walk
    rca_max_nodes: int = 200
    rca_max_hops: int | None = None
    rca_max_back_seconds: int = 30 * 60
    rca_forward_slack_seconds: int = 60
    rca_priority_mode: str = "baseline"

    # Contrastive training
    proj_dim: int = 128
    batch_size: int = 512
    epochs: int = 5
    lr: float = 1e-4
    temperature: float = 0.07
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    seed: int = 42

    # Scoring: weight of the contrastive score in score_norm_merged
    score_merge_w: float = 1.0

    # Evaluation @k
    ks: tuple = (5, 10, 20, 50)

    # These fields are declared last so existing positional order is preserved.
    # They must stay annotated: without annotations they would be plain class
    # attributes, missing from asdict()/config.json and not settable through
    # the constructor.
    svd_dim: int = 256
    svd_dim_list: tuple = (64, 128, 256, 512)   # tuple: dataclasses reject a mutable list default
    normalize_svd: bool = True
    normalize_proj: bool = True


cfg = ContrastiveRunConfig()

In [16]:
# Project layout: <project>/data holds the parquet, and contrastive run
# artifacts go under <project>/runs/contrastive.
# NOTE(review): assumes the notebook runs from a direct subdirectory of the
# project (cwd.parent is the root) — TODO confirm.
NOTEBOOK_DIR = Path.cwd()
PROJECT_DIR = NOTEBOOK_DIR.parent
DATASET_DIR = PROJECT_DIR / "data"
RUNS_DIR = PROJECT_DIR / "runs" / "contrastive"
RUNS_DIR.mkdir(parents=True, exist_ok=True)

def seed_everything(seed: int):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch with ``seed``.

    All CUDA devices are also seeded when CUDA is available. This does not
    enable PyTorch's deterministic-algorithm mode.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

seed_everything(cfg.seed)

DATA_PATH = DATASET_DIR / cfg.dataset_path
assert DATA_PATH.exists(), f"Missing: {DATA_PATH}"

# Same parquet as the baseline section. node_id = row position, so node ids
# agree with the baseline's for the same file.
df = pd.read_parquet(DATA_PATH).reset_index(drop=True)
df["node_id"] = df.index.astype(int)

# Hold out the golden episode (this overwrites the baseline section's train_df).
train_df = df[df["episode_id"] != cfg.golden_id].copy()
gold_df  = df[df["episode_id"] == cfg.golden_id].copy()

# The splits were copied before df's text was cleaned, so each one is cleaned too.
df[cfg.text_col] = df[cfg.text_col].fillna("").astype(str)
train_df[cfg.text_col] = train_df[cfg.text_col].fillna("").astype(str)
gold_df[cfg.text_col] = gold_df[cfg.text_col].fillna("").astype(str)

print("Loaded rows:", len(df))
print("Train rows:", len(train_df), "Golden rows:", len(gold_df))
print("Device:", cfg.device)

Loaded rows: 569803
Train rows: 539805 Golden rows: 29998
Device: cpu


In [17]:
# Base embeddings with TF-IDF
t0 = time.time()

tfidf = TfidfVectorizer(
    max_features=cfg.tfidf_max_features,
    min_df=cfg.tfidf_min_df,
    ngram_range=cfg.tfidf_ngram_range,
)

X_train_tfidf = tfidf.fit_transform(train_df[cfg.text_col])
X_gold_tfidf  = tfidf.transform(gold_df[cfg.text_col])

# SVD
svd = TruncatedSVD(n_components=cfg.svd_dim, random_state=cfg.seed)
Z_train = svd.fit_transform(X_train_tfidf)
Z_gold  = svd.transform(X_gold_tfidf)

#L2 normalization
if cfg.normalize_svd:
    norm = Normalizer(copy=False)
    Z_train = norm.fit_transform(Z_train)
    Z_gold  = norm.transform(Z_gold)

print("Base embeddings:", Z_train.shape, Z_gold.shape, f"({time.time()-t0:.2f}s)")

Base embeddings: (539805, 256) (29998, 256) (169.02s)


In [None]:
def mine_positive_pairs(ep: pd.DataFrame, window_sec: int, max_pos_per_anchor: int, max_anchors: int,
                        seed: int | None = None):
    """Mine (anchor, positive) ``node_id`` pairs of co-occurring events within one episode.

    Up to ``max_anchors`` anchor rows are sampled without replacement with a
    seeded RNG, so the selection is reproducible. For each anchor, the
    candidates are the other events within ``±window_sec`` seconds. Up to
    ``max_pos_per_anchor`` positives are taken in this order of preference:

    1. same actor IP, different stream;
    2. same actor IP, same stream;
    3. same host, different stream (skipping events already chosen).

    Within each tier, candidates are taken in timestamp order. Actor/host
    values that stringify to "none", "None", "nan" or "" never match. These
    are the same sentinels the graph builders skip; ``astype(str)`` turns a
    Python ``None`` into "None".

    Args:
        ep: one episode with ``timestamp``, ``actor_ip_anon``, ``host_anon``,
            ``stream`` and ``node_id`` columns.
        window_sec: half-width of the co-occurrence window, in seconds.
        max_pos_per_anchor: maximum number of positives per anchor.
        max_anchors: maximum number of anchors sampled from the episode.
        seed: RNG seed for anchor sampling; defaults to the global ``cfg.seed``.

    Returns:
        A list of ``(anchor_node_id, positive_node_id)`` tuples of Python ints.
        Both (a, b) and (b, a) may appear.
    """
    if seed is None:
        seed = cfg.seed
    missing = {"none", "None", "nan", ""}

    ep = ep.sort_values("timestamp").reset_index(drop=True)
    n = len(ep)
    if n == 0:
        return []

    anchor_idxs = np.sort(np.random.RandomState(seed).choice(n, size=min(n, max_anchors), replace=False))
    ts = ep["timestamp"].values.astype("datetime64[ns]")
    actor = ep["actor_ip_anon"].astype(str).values
    host = ep["host_anon"].astype(str).values
    stream = ep["stream"].astype(str).values
    # Positional node_id lookup. ep.iloc[i]["node_id"] would build a whole row Series per pair.
    node_ids = ep["node_id"].to_numpy()
    window = np.timedelta64(window_sec, "s")

    pairs = []
    for i in anchor_idxs:
        lo = np.searchsorted(ts, ts[i] - window, side="left")
        hi = np.searchsorted(ts, ts[i] + window, side="right")
        if hi - lo <= 1:
            continue  # the anchor is alone in its window
        cand = np.arange(lo, hi)
        cand = cand[cand != i]
        if len(cand) == 0:
            continue

        pos = []
        if actor[i] not in missing:
            same_actor = cand[actor[cand] == actor[i]]
            pos.extend(same_actor[stream[same_actor] != stream[i]][:max_pos_per_anchor].tolist())
            if len(pos) < max_pos_per_anchor:
                same_stream = same_actor[stream[same_actor] == stream[i]]
                pos.extend(same_stream[:max_pos_per_anchor - len(pos)].tolist())
        if len(pos) < max_pos_per_anchor and host[i] not in missing:
            same_host = cand[host[cand] == host[i]]
            for j in same_host[stream[same_host] != stream[i]]:
                if j not in pos:
                    pos.append(int(j))
                if len(pos) >= max_pos_per_anchor:
                    break

        anchor_nid = int(node_ids[i])
        pairs.extend((anchor_nid, int(node_ids[j])) for j in pos)
    return pairs

# Mine positive pairs separately within each training episode (pairs never
# cross episodes), then de-duplicate while keeping first-seen order.
t0 = time.time()
all_pairs = []
for eid, ep in train_df.groupby("episode_id"): # mining happens within an episode
    all_pairs.extend(
        mine_positive_pairs(ep, cfg.pair_window_sec, cfg.max_pos_per_anchor, cfg.max_anchors_per_episode)
    )
all_pairs = list(dict.fromkeys(all_pairs))
print("Mined positive pairs:", len(all_pairs), f"({time.time()-t0:.2f}s)")
assert len(all_pairs) > 0, "No positive pairs mined — increase window or check actor/host coverage."

In [None]:
# Map node_id -> row position in train_df. Z_train was computed from train_df's
# rows in the same order, so these positions index Z_train directly.
train_by_node = train_df[["node_id"]].reset_index(drop=True)
node_to_idx = dict(zip(train_by_node["node_id"].values, np.arange(len(train_by_node))))

# Convert (anchor, positive) node_id pairs into Z_train row positions. Pairs
# whose ids are not in train_df are dropped.
pair_idx = []
for a, p in all_pairs:     # a, p are node_ids
    ia = node_to_idx.get(a)
    ip = node_to_idx.get(p)
    if ia is not None and ip is not None:
        pair_idx.append((ia, ip)) # (anchor_row, positive_row) into Z_train

In [None]:
class PairDataset(Dataset):
    """Map-style dataset of (anchor, positive) embedding-vector pairs.

    ``Z`` is converted once to a float32 tensor. ``pairs`` holds row indices
    into ``Z``, and item ``k`` is ``(Z[a], Z[p])`` for ``pairs[k] == (a, p)``.
    """

    def __init__(self, Z: np.ndarray, pairs: list[tuple[int, int]]):
        self.Z = torch.tensor(Z, dtype=torch.float32)
        self.pairs = pairs

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        anchor_row, positive_row = self.pairs[idx]
        vectors = self.Z
        return vectors[anchor_row], vectors[positive_row]

ds_pairs = PairDataset(Z_train, pair_idx)
# Shuffled mini-batches of pairs. drop_last keeps every batch at exactly
# batch_size rows; InfoNCE uses the rest of the batch as negatives.
# NOTE(review): with fewer than cfg.batch_size pairs this yields zero batches.
dl = DataLoader(ds_pairs, batch_size=cfg.batch_size, shuffle=True, drop_last=True)

In [None]:
class Projector(nn.Module):
    """Two-layer MLP head that maps base (SVD) embeddings into the contrastive space.

    Architecture: Linear(in_dim, in_dim) -> ReLU -> Linear(in_dim, out_dim),
    held in ``self.net``. The state-dict keys are ``net.0.*`` and ``net.2.*``;
    keep them stable so saved ``projector.pt`` files still load. When
    ``normalize_out`` is true, each output row is L2-normalized so that dot
    products between outputs are cosine similarities.
    """

    def __init__(self, in_dim: int, out_dim: int, normalize_out: bool = True):
        super().__init__()
        self.normalize_out = normalize_out
        hidden = nn.Linear(in_dim, in_dim)   # learned mixing of the input features
        head = nn.Linear(in_dim, out_dim)    # projection down to out_dim
        self.net = nn.Sequential(hidden, nn.ReLU(), head)

    def forward(self, x):
        projected = self.net(x)
        if not self.normalize_out:
            return projected
        return F.normalize(projected, dim=1)


def info_nce_loss(z1, z2, temperature: float):
    """Symmetric InfoNCE loss over a batch of aligned pairs.

    Row ``i`` of ``z1`` and row ``i`` of ``z2`` form the positive pair, and every
    other row in the batch acts as a negative. The loss averages the
    cross-entropy over the rows of the ``z1 -> z2`` similarity matrix and over
    the rows of its transpose (``z2 -> z1``). Similarities are dot products
    divided by ``temperature``. They are cosine similarities only when the
    inputs are L2-normalized, as ``Projector`` does by default.

    Args:
        z1, z2: 2-D tensors with the same number of rows, on the same device
            (assumed shape (B, D)).
        temperature: softmax temperature (> 0). Smaller values sharpen the distribution.

    Returns:
        A scalar loss tensor.
    """
    targets = torch.arange(z1.size(0), device=z1.device)  # row i's positive is column i
    logits = (z1 @ z2.T) / temperature
    # z2 @ z1.T is mathematically logits.T, so reuse it instead of a second matmul.
    return 0.5 * (F.cross_entropy(logits, targets) + F.cross_entropy(logits.T, targets))

proj = Projector(cfg.svd_dim, cfg.proj_dim, normalize_out=cfg.normalize_proj).to(cfg.device)

# AdamW updates the projector weights. weight_decay=1e-4 applies decoupled
# weight decay, pulling weights toward zero unless the loss favors larger values.
opt = torch.optim.AdamW(proj.parameters(), lr=cfg.lr, weight_decay=1e-4)

# With drop_last=True, fewer pairs than batch_size yields zero batches, and the
# printed loss would silently be NaN (the mean of an empty list). Fail loudly instead.
assert len(dl) > 0, (
    f"No training batches: {len(ds_pairs)} pairs < batch_size={cfg.batch_size} with drop_last=True"
)

proj.train()
for epoch in range(cfg.epochs):
    losses = []
    for x1, x2 in dl:
        x1, x2 = x1.to(cfg.device), x2.to(cfg.device)   # move the batch to cfg.device (CPU or GPU)
        z1, z2 = proj(x1), proj(x2)
        loss = info_nce_loss(z1, z2, cfg.temperature)

        opt.zero_grad()   # gradients accumulate by default; reset them before each backward pass
        loss.backward()
        torch.nn.utils.clip_grad_norm_(proj.parameters(), 1.0)  # rescale so the total gradient norm is at most 1.0
        opt.step()

        losses.append(loss.item())
    print(f"Epoch {epoch+1}/{cfg.epochs} loss={np.mean(losses):.4f}")

In [None]:
# Save weights
RUN_DIR = RUNS_DIR / f"run_seed{cfg.seed}_svd{cfg.svd_dim}_proj{cfg.proj_dim}_win{cfg.pair_window_sec}"
RUN_DIR.mkdir(parents=True, exist_ok=True)
torch.save(proj.state_dict(), RUN_DIR / "projector.pt")
with (RUN_DIR / "config.json").open("w") as f:
    json.dump(asdict(cfg), f, indent=2)
print("Saved:", RUN_DIR)

In [None]:
# Embed every training and golden row with the trained projector (eval mode,
# no autograd). All rows go through a single forward pass, so peak memory grows
# with the number of rows.
proj.eval()
with torch.no_grad():
    Zt = torch.tensor(Z_train, dtype=torch.float32, device=cfg.device)
    Zg = torch.tensor(Z_gold, dtype=torch.float32, device=cfg.device)
    E_train = proj(Zt).cpu().numpy()
    E_gold  = proj(Zg).cpu().numpy()

print("Projected embeddings:", E_train.shape, E_gold.shape)

In [None]:
from sklearn.ensemble import IsolationForest  # NOTE(review): already imported in the first cell

# IsolationForest on the projected embeddings: fit on training rows, then score
# the golden episode.
# NOTE(review): n_estimators and contamination are hard-coded here, whereas the
# baseline reads them from its config.
iso_cl = IsolationForest(
    n_estimators=100,
    contamination="auto",
    random_state=cfg.seed,
    n_jobs=-1,
)

iso_cl.fit(E_train)

raw = iso_cl.decision_function(E_gold)     # higher = more normal
score_cl = -raw                            # higher = more anomalous

gold_df = gold_df.copy()
gold_df["score_cl"] = score_cl

In [None]:
def add_normalized_score(df: pd.DataFrame, col: str, out_col: str = "score_norm_cl") -> pd.DataFrame:
    """Return a copy of ``df`` with ``df[col]`` min-max scaled into ``out_col``.

    NOTE(review): this redefines (and shadows) the earlier ``add_normalized_score``.
    Here ``col`` is required and ``out_col`` defaults to ``"score_norm_cl"``.

    Non-missing values are mapped linearly onto [0, 1] (min -> 0, max -> 1).
    Missing values always become 0.0. If every value is missing, the whole
    column is 0.0. If all non-missing values are equal, they become 0.5 and
    missing rows stay 0.0. The input frame is not modified.
    """
    out = df.copy()
    raw = out[col]
    present = raw.notna()
    if not present.any():
        out[out_col] = 0.0
        return out

    lo, hi = float(raw[present].min()), float(raw[present].max())
    if hi == lo:
        # No spread to scale. Missing rows still get 0.0, matching the scaled branch.
        scaled = pd.Series(0.5, index=out.index).where(present, 0.0)
    else:
        scaled = ((raw - lo) / (hi - lo)).where(present, 0.0)
    out[out_col] = scaled
    return out

# Min-max scale the contrastive anomaly score into score_norm_cl in [0, 1].
gold_df = add_normalized_score(gold_df, col="score_cl", out_col="score_norm_cl")

In [None]:
# add_normalized_score has already stored score_norm_cl on gold_df, so there is
# nothing to copy (assigning the column to itself was a no-op). Check its
# expected [0, 1] range instead.
assert gold_df["score_norm_cl"].between(0.0, 1.0).all(), "score_norm_cl must lie in [0, 1]"

In [None]:
# Bring the baseline IsolationForest score (golden_df, from the baseline section)
# into gold_df. Join on node_id, the shared event key, rather than on index
# labels, and make sure every golden event received a score.
baseline_score_norm_by_node = golden_df.set_index("node_id")["score_norm"]
gold_df["score_norm"] = gold_df["node_id"].map(baseline_score_norm_by_node)
assert gold_df["score_norm"].notna().all(), "some golden events have no baseline score_norm"

In [None]:
# Blend the contrastive and baseline normalized scores:
#   merged = w * score_norm_cl + (1 - w) * score_norm,  with w = cfg.score_merge_w
# With the default w = 1.0, the merged score is just score_norm_cl.
gold_df["score_norm_merged"] = (
    cfg.score_merge_w * gold_df["score_norm_cl"] +
    (1.0 - cfg.score_merge_w) * gold_df["score_norm"]
)

In [None]:
# Id of the synthetic alert node in the contrastive graph. This rebinds the
# global that attach_alert_node reads at call time. Functions that took it as
# a default argument earlier (rca_walk) keep the old value "ALERT_NODE_ID".
ALERT_NODE_ID = "ALERT_NODE"

def build_episode_graph_from_scores(df: pd.DataFrame, cfg: ContrastiveRunConfig) -> nx.DiGraph:
    """Build the temporal event graph used by the contrastive RCA walk.

    There is one node per row, keyed by ``int(node_id)``. Node attributes are
    ``timestamp``, ``stream``, ``text`` (the masked message), the stringified
    ``actor_ip_anon`` / ``host_anon``, ``score_cl`` and ``score_norm``.
    ``score_norm`` is taken from the ``score_norm_merged`` column (0.0 when the
    column is absent) and is the value the RCA walk prioritizes on.

    Consecutive events, in timestamp order, that share an actor IP or host are
    linked by ``actor_ip_temporal`` or ``host_temporal`` edges. A link is added
    when the gap is at most ``cfg.graph_max_gap_actor`` or
    ``cfg.graph_max_gap_host`` seconds respectively. Keys "none", "", "nan" and
    "None" are never linked. Host edges are added second and overwrite the
    attributes of an existing actor edge between the same pair.
    """
    ordered = df.sort_values("timestamp").copy()
    graph = nx.DiGraph()
    skip_keys = ("none", "", "nan", "None")

    for _, rec in ordered.iterrows():
        node_key = int(rec["node_id"])
        graph.add_node(
            node_key,
            timestamp=rec["timestamp"],
            stream=rec["stream"],
            text=rec["masked_message_cl"],
            actor_ip_anon=str(rec.get("actor_ip_anon", "none")),
            host_anon=str(rec.get("host_anon", "none")),
            score_cl=float(rec.get("score_cl", 0.0)),
            score_norm=float(rec.get("score_norm_merged", 0.0)),
        )

    def chain(key_col: str, max_gap: int, kind: str) -> None:
        # Link each key's events in time order; a pair whose gap exceeds max_gap is skipped.
        for key, members in ordered.groupby(key_col):
            if str(key) in skip_keys:
                continue
            prev_node = prev_time = None
            for _, rec in members.sort_values("timestamp").iterrows():
                cur_node, cur_time = int(rec["node_id"]), rec["timestamp"]
                if prev_node is not None:
                    gap = (cur_time - prev_time).total_seconds()
                    if 0 <= gap <= max_gap:
                        graph.add_edge(prev_node, cur_node, kind=kind, dt=float(gap))
                prev_node, prev_time = cur_node, cur_time

    edge_specs = (
        (cfg.use_actor_edges, "actor_ip_anon", cfg.graph_max_gap_actor, "actor_ip_temporal"),
        (cfg.use_host_edges, "host_anon", cfg.graph_max_gap_host, "host_temporal"),
    )
    for enabled, column, max_gap, kind in edge_specs:
        if enabled:
            chain(column, max_gap, kind)

    return graph

# Contrastive temporal graph for the golden episode (node score_norm = score_norm_merged).
G = build_episode_graph_from_scores(gold_df, cfg)
print("Graph nodes/edges:", G.number_of_nodes(), G.number_of_edges())

In [None]:
# Load the golden alert (first CSV row) for the contrastive graph.
alert_path = PROJECT_DIR / cfg.alert_csv_path
assert alert_path.exists(), f"Missing alert csv: {alert_path}"
alert_df = pd.read_csv(alert_path)
assert len(alert_df) > 0, f"Alert csv has no rows: {alert_path}"
alert_row = alert_df.iloc[0]


def _alert_field(name: str) -> str:
    """Return ``alert_row[name]`` as text; absent or empty CSV cells (NaN) become ""."""
    # `value or ""` is not enough: NaN is truthy, so str() would produce "nan".
    value = alert_row.get(name, "")
    return "" if pd.isna(value) else str(value)


alert_ts = pd.to_datetime(alert_row["@timestamp"], utc=True)
alert_src_ip = alert_row.get("source.ip", None)
alert_name   = _alert_field("kibana.alert.rule.name")
alert_reason = _alert_field("kibana.alert.reason")
alert_desc   = _alert_field("description")

In [None]:
# Attach the alert to the contrastive graph. attach_alert_node reads the current
# ALERT_NODE_ID ("ALERT_NODE"), which matches rca_walk_score's default.
# NOTE(review): the alert node stores its message under "masked_message_cl",
# while log nodes in this graph store theirs under "text".
attached_log_ids = attach_alert_node(G, gold_df, alert_ts, alert_src_ip, alert_name, alert_reason, alert_desc)
save_json({"attached_log_ids": attached_log_ids}, RUN_DIR / "alert_attach.json")
print("Attached node_ids:", attached_log_ids)

In [None]:
# Sanity checks on how the alert hooks into the graph: which event it is
# attached to, that event's degree, and how many events can reach it. Its
# ancestors are an upper bound on what the backward RCA walk can visit.
assert ALERT_NODE_ID in G.nodes

succ = list(G.successors(ALERT_NODE_ID))
print("Alert successors:", succ[:10], "count=", len(succ))

# NOTE(review): raises IndexError if the alert has no successor; `succ[0]`
# would avoid recomputing the list.
attached_nid = list(G.successors(ALERT_NODE_ID))[0]

print("Outgoing from attached:", G.out_degree(attached_nid))
print("Incoming to attached:", G.in_degree(attached_nid))

anc = nx.ancestors(G, attached_nid) - {ALERT_NODE_ID}

print("Ancestors of attached:", len(anc))

# Stream mix of the events that can reach the attached event.
anc_df = gold_df.set_index("node_id").loc[list(anc)]
print(anc_df["stream"].value_counts().head(10))

In [None]:
def rca_walk_score(G: nx.DiGraph, cfg: ContrastiveRunConfig, alert_node_id: str = ALERT_NODE_ID):
    """Best-first backward walk from the alert node over the contrastive graph.

    The frontier is seeded with the alert node's successors (the attached log
    events). The walk then repeatedly selects the frontier node with the best
    ``priority_key`` and pushes that node's predecessors. Only nodes with
    ``-cfg.rca_forward_slack_seconds <= alert_ts - ts <= cfg.rca_max_back_seconds``
    are considered. Nodes beyond ``cfg.rca_max_hops`` (if set) are skipped, and
    the walk stops once ``cfg.rca_max_nodes`` nodes (alert included) are selected.

    Priority modes:
        "baseline": highest ``score_norm`` first, then smallest
        ``|alert_ts - ts|``. Remaining ties go to fewer hops, then node id.

    Returns:
        (copy of the induced subgraph, selected node set, nodes in selection
        order starting with the alert node)

    Raises:
        ValueError: for an unknown ``cfg.rca_priority_mode``. This is checked
            before walking, so it is reported even when nothing would be selected.
    """
    import math

    supported_modes = ("baseline",)
    if cfg.rca_priority_mode not in supported_modes:
        raise ValueError(f"Unknown rca_priority_mode={cfg.rca_priority_mode}")

    alert_ts = G.nodes[alert_node_id]["timestamp"]

    def get_score_norm(node_id):
        if node_id == alert_node_id:
            return 1.0
        score = float(G.nodes[node_id].get("score_norm", 0.0) or 0.0)
        # NaN is truthy, so `or 0.0` lets it through. A NaN priority makes every
        # heap comparison False and silently breaks the best-first order.
        return 0.0 if math.isnan(score) else score

    def priority_key(score: float, dt: float):
        adt = abs(dt)
        if cfg.rca_priority_mode == "baseline":
            return (-score, adt)
        raise ValueError(f"Unknown rca_priority_mode={cfg.rca_priority_mode}")

    def in_window(dt: float) -> bool:
        # False for NaN dt (e.g. a NaT timestamp), so such nodes are never queued.
        return -cfg.rca_forward_slack_seconds <= dt <= cfg.rca_max_back_seconds

    selected = {alert_node_id}
    visited  = {alert_node_id}
    ranked   = [alert_node_id]

    heap = []
    for succ in G.successors(alert_node_id):
        dt = (alert_ts - G.nodes[succ]["timestamp"]).total_seconds()
        if in_window(dt):
            heapq.heappush(heap, (priority_key(get_score_norm(succ), dt), 1, succ))

    while heap and len(selected) < cfg.rca_max_nodes:
        _, hops, node_id = heapq.heappop(heap)

        if node_id in visited:
            continue  # stale duplicate entry
        if cfg.rca_max_hops is not None and hops > cfg.rca_max_hops:
            continue

        visited.add(node_id)
        selected.add(node_id)
        ranked.append(node_id)

        for pred in G.predecessors(node_id):
            if pred in visited or pred == alert_node_id:
                continue
            dt = (alert_ts - G.nodes[pred]["timestamp"]).total_seconds()
            if in_window(dt):
                heapq.heappush(heap, (priority_key(get_score_norm(pred), dt), hops + 1, pred))

    subG = G.subgraph(selected).copy()
    return subG, selected, ranked

In [None]:
def rca_walk_score_with_hops(G: nx.DiGraph, cfg: ContrastiveRunConfig, alert_node_id: str = ALERT_NODE_ID):
    """Same walk as ``rca_walk_score`` that also records the hop depth of every selected node.

    The frontier is seeded with the alert's successors (hop 1); every selected
    node then contributes its predecessors at ``hops + 1``. Candidates are popped
    by highest ``score_norm`` first, ties broken by smallest ``|alert_ts - ts|``,
    then by fewest hops. A candidate with offset ``dt = (alert_ts - ts)`` in
    seconds is skipped when ``dt > cfg.rca_max_back_seconds`` or
    ``dt < -cfg.rca_forward_slack_seconds``.

    Args:
        G: episode graph; every node needs a ``"timestamp"`` attribute and
            ``"score_norm"`` is read when present (missing/None/NaN count as 0.0).
        cfg: reads ``rca_priority_mode``, ``rca_max_nodes`` (budget that includes
            the alert node), ``rca_max_hops`` (None = unlimited),
            ``rca_max_back_seconds`` and ``rca_forward_slack_seconds``.
        alert_node_id: id of the alert node attached to ``G`` (defaults to
            ``ALERT_NODE_ID``, like ``rca_walk_score``).

    Returns:
        ``(subG, selected, ranked, hop_of)``: as ``rca_walk_score`` plus a dict
        mapping each selected node id to the hop count it was selected at
        (alert = 0).

    Raises:
        ValueError: if ``cfg.rca_priority_mode`` is not ``"baseline"``.
    """
    # Validate up front so a bad config fails even when there are no candidates.
    if cfg.rca_priority_mode != "baseline":
        raise ValueError(f"Unknown rca_priority_mode={cfg.rca_priority_mode}")

    alert_ts = G.nodes[alert_node_id]["timestamp"]
    max_hops = cfg.rca_max_hops
    heap: list = []

    def get_score_norm(node_id) -> float:
        if node_id == alert_node_id:
            return 1.0
        raw = G.nodes[node_id].get("score_norm")
        # NaN compares False against everything and silently breaks heapq's
        # ordering, so treat it like a missing score.
        if pd.isna(raw):
            return 0.0
        return float(raw or 0.0)

    def push(node_id, hops: int) -> None:
        # Entries beyond the hop limit would only be discarded when popped; skip them now.
        if max_hops is not None and hops > max_hops:
            return
        dt = (alert_ts - G.nodes[node_id]["timestamp"]).total_seconds()
        if dt > cfg.rca_max_back_seconds or dt < -cfg.rca_forward_slack_seconds:
            return
        priority = (-get_score_norm(node_id), abs(dt))
        heapq.heappush(heap, (priority, hops, node_id))

    selected = {alert_node_id}
    visited = {alert_node_id}
    ranked = [alert_node_id]
    hop_of = {alert_node_id: 0}

    for succ in G.successors(alert_node_id):
        push(succ, 1)

    while heap and len(selected) < cfg.rca_max_nodes:
        _, hops, node_id = heapq.heappop(heap)
        # Entries for one node share a priority, so the lowest-hop entry pops first.
        if node_id in visited:
            continue

        visited.add(node_id)
        selected.add(node_id)
        ranked.append(node_id)
        hop_of[node_id] = hops

        for pred in G.predecessors(node_id):
            if pred not in visited and pred != alert_node_id:
                push(pred, hops + 1)

    subG = G.subgraph(selected).copy()
    return subG, selected, ranked, hop_of

In [None]:
# Effectively unbounded walk (no hop limit, huge node budget), used only to measure hop depth.
cfg_tmp = replace(cfg, rca_max_nodes=5000, rca_max_hops=None)

In [None]:
# Hop-depth distribution of the unbounded walk. The alert id is passed explicitly so the
# walk and the exclusion below always refer to the same node.
_, selected, ranked, hop_of = rca_walk_score_with_hops(G, cfg_tmp, alert_node_id=ALERT_NODE_ID)

hop_s = pd.Series({k: v for k, v in hop_of.items() if k != ALERT_NODE_ID}, dtype="int64")

if hop_s.empty:
    # max() of an empty Series is NaN and int(NaN) raises; report instead of crashing.
    print("Max hop (selected): n/a (walk selected no nodes besides the alert)")
else:
    print("Max hop (selected):", int(hop_s.max()))
    print(hop_s.value_counts().sort_index().head(30))

In [None]:
# Ground-truth node ids; bool-cast so 0/1 label columns also work as masks
# (matching how the evaluation helpers read gt_core / gt_extended).
gt_core_nodes = set(gold_df.loc[gold_df["gt_core"].astype(bool), "node_id"].astype(int))
gt_ext_nodes  = set(gold_df.loc[gold_df["gt_extended"].astype(bool), "node_id"].astype(int))

# Hop depth at which the unbounded walk selected each GT node (unreached nodes are absent).
core_hops = [hop_of[n] for n in gt_core_nodes if n in hop_of]
ext_hops  = [hop_of[n] for n in gt_ext_nodes  if n in hop_of]

for label, hops in (("Core", core_hops), ("Ext ", ext_hops)):
    if hops:
        print(f"{label} GT hops: max=", max(hops), "p90=", pd.Series(hops).quantile(0.90))
    else:
        # max() of an empty list raises ValueError; report the situation instead.
        print(f"{label} GT hops: none reached by the walk")

In [None]:
# Main RCA walk with the run config; later evaluation cells read `rca_ranked`.
subG, selected_nodes, rca_ranked = rca_walk_score(G, cfg, alert_node_id=ALERT_NODE_ID)
print("RCA subgraph:")
print(f"  Nodes: {subG.number_of_nodes()}")
print(f"  Edges: {subG.number_of_edges()}")

## EVALUATION

In [None]:
# gold_df indexed by node id for per-node lookups; drop=False keeps "node_id" as a column too.
golden_by_id = gold_df.set_index("node_id", drop=False)

def _filter_ranked_nodes(rca_ranked: List[Any], alert_node_id: Any) -> List[int]:
    return [int(n) for n in rca_ranked if n != alert_node_id and n != "ALERT_NODE"]

def _ranked_to_items(ranked_nodes: List[int], df_by_id: pd.DataFrame, use_or_duplicates: bool, evidence_col: str="evidence_id") -> List[str]:
    if not use_or_duplicates:
        return [str(n) for n in ranked_nodes]
    ranked_items, seen = [], set()
    for n in ranked_nodes:
        eid = df_by_id.loc[n, evidence_col]
        eid = "missing_evidence" if pd.isna(eid) else str(eid)
        if eid not in seen:
            seen.add(eid)
            ranked_items.append(eid)
    return ranked_items

def _gt_items(df: pd.DataFrame, gt_col: str, use_or_duplicates: bool, evidence_col: str="evidence_id") -> set:
    gt_mask = df[gt_col].astype(bool)
    if not use_or_duplicates:
        return set(map(str, df.loc[gt_mask, "node_id"].tolist()))
    return set(map(str, df.loc[gt_mask, evidence_col].dropna().tolist()))

def _prf(returned_set: set, gt_set: set) -> Tuple[int, float, float, float]:
    tp = len(returned_set & gt_set)
    p = tp / max(1, len(returned_set))
    r = tp / max(1, len(gt_set))
    f1 = 0.0 if (p + r) == 0 else (2 * p * r / (p + r))
    return tp, p, r, f1

def _pr_at_k(ranked_items: List[str], gt_set: set, k: int) -> Tuple[int, float, float, int]:
    effective_k = min(k, len(ranked_items))
    topk = set(ranked_items[:effective_k])
    tp = len(topk & gt_set)
    p = tp / max(1, effective_k)
    r = tp / max(1, len(gt_set))
    return tp, p, r, effective_k

def _hit_at_k(ranked_items: List[str], gt_set: set, k: int) -> Tuple[int, int]:
    k_used = min(k, len(ranked_items))
    if k_used == 0:
        return 0, 0
    topk = set(ranked_items[:k_used])
    hit = 1 if len(topk & gt_set) > 0 else 0
    return hit, k_used


def evaluate_rca_episode(
    df: pd.DataFrame,
    df_by_id: pd.DataFrame,
    rca_ranked: List[Any],
    cfg: ContrastiveRunConfig,
    use_or_duplicates: bool = True,
    evidence_col: str = "evidence_id",
    compute_hit_for: str = "core",
) -> Dict[str, Any]:
    """Set-level and top-k metrics of an RCA ranking against gt_core / gt_extended.

    Args:
        df: episode frame with ``node_id``, ``gt_core``, ``gt_extended`` and ``evidence_col``.
        df_by_id: ``df`` indexed by node id, used to look up evidence ids.
        rca_ranked: walk output in selection order; the alert node is dropped.
        cfg: reads ``cfg.ks`` (the k values for @k metrics).
        use_or_duplicates: score evidence ids (duplicate messages OR-ed together)
            instead of individual nodes.
        evidence_col: column holding the evidence id.
        compute_hit_for: which GT set(s) Hit@k is reported for: ``"core"``,
            ``"ext"`` or ``"both"``; any other value skips Hit@k.

    Returns:
        Flat dict with sizes, set-level P/R/F1 for core and ext, and for each k in
        ``cfg.ks`` the tp/P/R/k_used values (plus Hit@k for the requested sets).
    """
    ranked_nodes = _filter_ranked_nodes(rca_ranked, ALERT_NODE_ID)
    ranked_items = _ranked_to_items(ranked_nodes, df_by_id, use_or_duplicates, evidence_col=evidence_col)
    returned_set = set(ranked_items)

    gt_sets = {
        "core": _gt_items(df, "gt_core", use_or_duplicates, evidence_col=evidence_col),
        "ext": _gt_items(df, "gt_extended", use_or_duplicates, evidence_col=evidence_col),
    }

    out: Dict[str, Any] = {
        "mode": "or_duplicates" if use_or_duplicates else "node_level",
        "S_nodes": len(ranked_nodes),
        "S_items": len(returned_set),
        "returned_items_total_ranked": len(ranked_items),
        "gt_core_size": len(gt_sets["core"]),
        "gt_ext_size": len(gt_sets["ext"]),
    }
    for prefix, gt in gt_sets.items():
        tp, p, r, f1 = _prf(returned_set, gt)
        out[f"{prefix}_tp"] = tp
        out[f"{prefix}_precision"] = p
        out[f"{prefix}_recall"] = r
        out[f"{prefix}_f1"] = f1

    # Exact-match lookup. A parenthesised string such as ("core") is not a tuple,
    # so `x in ("core")` would be a substring test on "core".
    hit_prefixes = {"core": ("core",), "ext": ("ext",), "both": ("core", "ext")}.get(compute_hit_for, ())

    for k in cfg.ks:
        for prefix, gt in gt_sets.items():
            tp, p, r, k_used = _pr_at_k(ranked_items, gt, k)
            out[f"{prefix}_tp@{k}"] = tp
            out[f"{prefix}_P@{k}"] = p
            out[f"{prefix}_R@{k}"] = r
            out[f"{prefix}_k_used@{k}"] = k_used

        for prefix in hit_prefixes:
            hit, k_used = _hit_at_k(ranked_items, gt_sets[prefix], k)
            out[f"{prefix}_Hit@{k}"] = hit
            out[f"{prefix}_hit_k_used@{k}"] = k_used

    return out

# Score the main walk twice: at evidence level (duplicate messages sharing an evidence id
# count once) and at raw node level. Both use Hit@k for the core GT set.
metrics_or = evaluate_rca_episode(gold_df, golden_by_id, rca_ranked, cfg, use_or_duplicates=True, compute_hit_for="core")
metrics_node = evaluate_rca_episode(gold_df, golden_by_id, rca_ranked, cfg, use_or_duplicates=False, compute_hit_for="core")

# Persist both metric dicts and the raw ranking (ids stringified so mixed id types serialise).
save_json(metrics_or, RUN_DIR / "metrics_or_duplicates.json")
save_json(metrics_node, RUN_DIR / "metrics_node_level.json")
save_json({"rca_ranked": [str(x) for x in rca_ranked]}, RUN_DIR / "rca_ranked.json")

metrics_or

In [None]:
def per_stream_breakdown(gold_df: pd.DataFrame, selected_node_ids: set, gt_col: str = "gt_core") -> pd.DataFrame:
    """Per-stream confusion counts and precision/recall of a node selection.

    Args:
        gold_df: episode frame with ``node_id``, ``stream`` and ``gt_col``.
        selected_node_ids: node ids picked by the walk.
        gt_col: label column (truthy = ground truth).

    Returns:
        One row per ``stream`` value (index named "stream") with integer columns
        tp, fp, fn, selected, gt_total, total_logs and float columns precision,
        recall (0.0 where undefined), sorted by gt_total then selected, descending.
        An empty ``gold_df`` yields an empty frame with the same columns.
    """
    is_gt = gold_df[gt_col].astype(bool)
    is_sel = gold_df["node_id"].isin(selected_node_ids)

    # Vectorised counts: summing per-row boolean flags per stream avoids
    # groupby.apply over the grouping column (deprecated) and works on empty input.
    flags = pd.DataFrame({
        "stream": gold_df["stream"],
        "tp": is_gt & is_sel,
        "fp": ~is_gt & is_sel,
        "fn": is_gt & ~is_sel,
        "selected": is_sel,
        "gt_total": is_gt,
        "total_logs": 1,
    })
    out = flags.groupby("stream").sum().astype("int64")

    # Zero denominators become NaN (not pd.NA) so the result stays float64, then 0.0.
    out["precision"] = (out["tp"] / (out["tp"] + out["fp"]).replace(0, np.nan)).fillna(0.0)
    out["recall"] = (out["tp"] / (out["tp"] + out["fn"]).replace(0, np.nan)).fillna(0.0)
    return out.sort_values(["gt_total", "selected"], ascending=False)

def get_missed_and_extra_tables(gold_df: pd.DataFrame, selected_node_ids: set, gt_col: str="gt_core", top_n: int=20):
    """Return (missed, extra): up to ``top_n`` GT rows the walk did not select and
    non-GT rows it did select, each ordered by score_norm descending then timestamp
    ascending, restricted to a fixed set of display columns.
    """
    columns = ["timestamp", "stream", "masked_message_cl", "score_norm", "evidence_id", "node_id"]
    is_gt = gold_df[gt_col].astype(bool)
    is_sel = gold_df["node_id"].isin(selected_node_ids)

    def _top(mask):
        ordered = gold_df[mask].sort_values(["score_norm", "timestamp"], ascending=[False, True])
        return ordered[columns].head(top_n)

    return _top(is_gt & ~is_sel), _top(~is_gt & is_sel)

# Node ids picked by the walk (alert excluded), as ints to match gold_df["node_id"].
selected_node_ids = set(_filter_ranked_nodes(rca_ranked, ALERT_NODE_ID))

per_stream_core = per_stream_breakdown(gold_df, selected_node_ids, gt_col="gt_core")
per_stream_ext = per_stream_breakdown(gold_df, selected_node_ids, gt_col="gt_extended")

missed_core, extra_non_core = get_missed_and_extra_tables(
    gold_df, selected_node_ids, gt_col="gt_core", top_n=20
)

for stem, table in (("per_stream_core", per_stream_core), ("per_stream_ext", per_stream_ext)):
    save_json(table.to_dict(orient="index"), RUN_DIR / f"{stem}.json")

print("Saved diagnostics to:", RUN_DIR)

# EXPERIMENTS

## EXPERIMENT A: Connectivity/Reachability

In [None]:
def compute_reachable_by_walk(G: nx.DiGraph, cfg: ContrastiveRunConfig, alert_node_id: str = ALERT_NODE_ID) -> set[int]:
    """Nodes the RCA walk could ever reach: same traversal rules, no scores, no node budget.

    Breadth-first search seeded with the alert's successors (hop 1) that then
    follows predecessors. It applies the walk's time window
    (``rca_max_back_seconds`` back, ``rca_forward_slack_seconds`` forward) and
    ``rca_max_hops`` (None = unlimited).

    Returns:
        Set of reachable node ids cast to int; the alert node is never included.
    """
    alert_ts = G.nodes[alert_node_id]["timestamp"]

    def in_window(node) -> bool:
        offset = (alert_ts - G.nodes[node]["timestamp"]).total_seconds()
        too_old = offset > cfg.rca_max_back_seconds
        too_late = offset < -cfg.rca_forward_slack_seconds
        return not (too_old or too_late)

    seen = {alert_node_id}
    found = set()
    frontier = deque(
        (succ, 1)
        for succ in G.successors(alert_node_id)
        if succ != alert_node_id and in_window(succ)
    )

    while frontier:
        node, depth = frontier.popleft()
        if node in seen:
            continue
        if cfg.rca_max_hops is not None and depth > cfg.rca_max_hops:
            continue

        seen.add(node)
        # The alert is filtered out at seeding and expansion, so it never gets here.
        found.add(int(node))

        frontier.extend(
            (pred, depth + 1)
            for pred in G.predecessors(node)
            if pred not in seen and pred != alert_node_id and in_window(pred)
        )

    return found


def _evidence_series_with_fallback(df: pd.DataFrame) -> pd.Series:
    ev = df["evidence_id"] if "evidence_id" in df.columns else pd.Series([pd.NA] * len(df), index=df.index)
    ev = ev.where(ev.notna(), df["node_id"].apply(lambda x: f"node_{int(x)}"))
    ev = ev.astype(str)
    ev = ev.replace({"nan": "", "none": "", "None": ""})
    return ev


def node_ids_to_evidence_set(gold_df: pd.DataFrame, node_ids: set[int]) -> set[str]:
    """Evidence ids (with "node_<id>" fallback) of the given nodes; blank ids are dropped."""
    rows = gold_df.loc[gold_df["node_id"].isin(node_ids), ["node_id", "evidence_id"]]
    if rows.empty:
        return set()
    keys = set(_evidence_series_with_fallback(rows).tolist())
    keys.discard("")
    return keys


def gt_evidence_set(gold_df: pd.DataFrame, gt_col: str) -> set[str]:
    """Evidence ids (with "node_<id>" fallback) of rows labelled positive in ``gt_col``; blanks dropped."""
    rows = gold_df.loc[gold_df[gt_col].astype(bool), ["node_id", "evidence_id"]]
    if rows.empty:
        return set()
    keys = set(_evidence_series_with_fallback(rows).tolist())
    keys.discard("")
    return keys


def reachable_recall_evidence(
    gold_df: pd.DataFrame,
    reachable_node_ids: set[int],
    gt_col: str,
) -> tuple[int, int, float]:
    """(hits, total, recall) of GT evidence items covered by the reachable nodes.

    Returns (0, 0, 0.0) when ``gt_col`` has no positive evidence.
    """
    gt_ev = gt_evidence_set(gold_df, gt_col)
    total = len(gt_ev)
    if not total:
        return 0, 0, 0.0
    hits = len(gt_ev.intersection(node_ids_to_evidence_set(gold_df, reachable_node_ids)))
    return hits, total, hits / total

def build_graph_and_attach(
    gold_df: pd.DataFrame,
    cfg_variant: ContrastiveRunConfig,
    alert_ts: pd.Timestamp,
    alert_src_ip,
    alert_name: str,
    alert_reason: str,
    alert_desc: str,
) -> nx.DiGraph:
    """Build the episode graph for ``cfg_variant`` and attach the alert node to it.

    attach_alert_node is called for its effect on the graph; its return value is ignored.
    """
    graph = build_episode_graph_from_scores(gold_df, cfg_variant)
    attach_alert_node(graph, gold_df, alert_ts, alert_src_ip, alert_name, alert_reason, alert_desc)
    return graph

def run_reachability_ablation_A_evidence(
    gold_df: pd.DataFrame,
    cfg_base: ContrastiveRunConfig,
    alert_ts: pd.Timestamp,
    alert_src_ip,
    alert_name: str,
    alert_reason: str,
    alert_desc: str,
    actor_gaps=(30, 60, 120, 300, 600),
    host_gaps=(30, 60, 120, 300),
):
    """Experiment A: evidence-level reachability under graph-construction ablations.

    Three groups of variants are evaluated, each on a freshly built graph with the
    alert attached:
      * "edge_type": both / actor-only / host-only / no edges, at the base gaps;
      * "actor_gap_sweep": both edge types, actor gap from ``actor_gaps``;
      * "host_gap_sweep": both edge types, host gap from ``host_gaps``.
    For each variant, ``compute_reachable_by_walk`` gives the reachable node set,
    and GT recall is measured over evidence ids for gt_core and gt_extended.

    Returns:
        DataFrame with one row per variant, sorted by ablation, then reachable
        recall (core, ext), reachable evidence and edge count (descending).
    """
    core_gt_ev = gt_evidence_set(gold_df, "gt_core")
    ext_gt_ev = gt_evidence_set(gold_df, "gt_extended")

    def _reachability_row(ablation, variant, use_actor, use_host, actor_gap, host_gap):
        # Build one graph variant, run the unscored reachability walk, and summarise it.
        cfgv = replace(
            cfg_base,
            use_actor_edges=use_actor,
            use_host_edges=use_host,
            graph_max_gap_actor=actor_gap,
            graph_max_gap_host=host_gap,
        )
        Gv = build_graph_and_attach(gold_df, cfgv, alert_ts, alert_src_ip, alert_name, alert_reason, alert_desc)
        reachable_nodes = compute_reachable_by_walk(Gv, cfgv, ALERT_NODE_ID)
        reachable_ev = node_ids_to_evidence_set(gold_df, reachable_nodes)

        core_hit, core_tot, rr_core = reachable_recall_evidence(gold_df, reachable_nodes, "gt_core")
        ext_hit, ext_tot, rr_ext = reachable_recall_evidence(gold_df, reachable_nodes, "gt_extended")

        return {
            "ablation": ablation,
            "variant": variant,
            "use_actor_edges": use_actor,
            "use_host_edges": use_host,
            "actor_gap_s": int(actor_gap),
            "host_gap_s": int(host_gap),
            "nodes": Gv.number_of_nodes(),
            "edges": Gv.number_of_edges(),
            "reachable_nodes": int(len(reachable_nodes)),
            "reachable_evidence": int(len(reachable_ev)),
            "core_gt_evidence_total": int(len(core_gt_ev)),
            "ext_gt_evidence_total": int(len(ext_gt_ev)),
            "ReachRec_core_evidence": rr_core,
            "core_reachable_evidence": f"{core_hit}/{core_tot}",
            "ReachRec_ext_evidence": rr_ext,
            "ext_reachable_evidence": f"{ext_hit}/{ext_tot}",
        }

    edge_variants = (
        ("both_edges", True, True),
        ("actor_only", True, False),
        ("host_only", False, True),
        ("no_edges", False, False),
    )
    rows = [
        _reachability_row(
            "edge_type", name, ua, uh,
            cfg_base.graph_max_gap_actor, cfg_base.graph_max_gap_host,
        )
        for name, ua, uh in edge_variants
    ]
    rows += [
        _reachability_row(
            "actor_gap_sweep", f"actor_gap={a_gap}", True, True,
            int(a_gap), int(cfg_base.graph_max_gap_host),
        )
        for a_gap in actor_gaps
    ]
    rows += [
        _reachability_row(
            "host_gap_sweep", f"host_gap={h_gap}", True, True,
            int(cfg_base.graph_max_gap_actor), int(h_gap),
        )
        for h_gap in host_gaps
    ]

    out = pd.DataFrame(rows)
    return out.sort_values(
        ["ablation", "ReachRec_core_evidence", "ReachRec_ext_evidence", "reachable_evidence", "edges"],
        ascending=[True, False, False, False, False],
    )

# Run Experiment A with the notebook's run config and the alert fields loaded earlier.
# Each variant rebuilds its own graph, so the notebook-level G is not used here.
ablation_A_evidence = run_reachability_ablation_A_evidence(
    gold_df=gold_df,
    cfg_base=cfg,
    alert_ts=alert_ts,
    alert_src_ip=alert_src_ip,
    alert_name=alert_name,
    alert_reason=alert_reason,
    alert_desc=alert_desc,
)

display(ablation_A_evidence)

In [None]:
# Persist the Experiment A table next to the other run artefacts.
ablation_A_path = RUN_DIR / "ablation_A_reachability.csv"
ablation_A_evidence.to_csv(ablation_A_path, index=False)
print(f"Saved: {ablation_A_path}")

# EXPERIMENT B: Score Blending

In [None]:
# Sweep the blend weight w in  score_norm_merged = w * score_norm_cl + (1 - w) * score_norm.
# Each weight works on its own copy of gold_df and its own graph, so the notebook-level
# gold_df, G, subG and rca_ranked used by later cells (e.g. Experiment C's G=G) stay intact.
weights = [0.0, 0.25, 0.5, 0.75, 1.0]
results = []

for w in weights:
    gold_w = gold_df.assign(
        score_norm_merged=w * gold_df["score_norm_cl"] + (1 - w) * gold_df["score_norm"]
    )
    G_w = build_episode_graph_from_scores(gold_w, cfg)
    attach_alert_node(G_w, gold_w, alert_ts, alert_src_ip, alert_name, alert_reason, alert_desc)
    _, _, ranked_w = rca_walk_score(G_w, cfg, alert_node_id=ALERT_NODE_ID)
    metrics = evaluate_rca_episode(
        gold_w, gold_w.set_index("node_id", drop=False), ranked_w, cfg, use_or_duplicates=True
    )
    metrics["w"] = w
    results.append(metrics)

pd.DataFrame(results)[["w","core_f1","core_P@10","core_R@10","ext_f1","ext_P@10","ext_R@10","ext_R@20"]]

# EXPERIMENT C: Walk budget + hop limit

In [None]:
def _run_walk_eval(
    G,
    gold_df,
    cfgv,
    alert_node_id=ALERT_NODE_ID,
    mode_name="contrastive",
    use_or_duplicates=True,
    compute_hop_diagnostics=True,
):
    """Run one RCA walk on ``G`` with ``cfgv`` and flatten its metrics into a table row.

    Args:
        G: episode graph with the alert node attached.
        gold_df: episode frame supplying both labels and evidence ids for evaluation.
        cfgv: walk config; ``rca_max_nodes`` / ``rca_max_hops`` are recorded in the row.
        alert_node_id: id of the alert node in ``G``.
        mode_name: label stored in the ``"model"`` column.
        use_or_duplicates: evaluate at evidence level (True) or node level (False).
        compute_hop_diagnostics: when ``rca_walk_score_with_hops`` is defined, also
            report max / p90 / mean hop depth of the selected nodes.

    Returns:
        dict: run identifiers, all ``evaluate_rca_episode`` metrics and, optionally,
        hop statistics.
    """
    hop_of = None
    if compute_hop_diagnostics and "rca_walk_score_with_hops" in globals():
        _, selected_v, ranked_v, hop_of = rca_walk_score_with_hops(
            G, cfgv, alert_node_id=alert_node_id
        )
    else:
        _, selected_v, ranked_v = rca_walk_score(G, cfgv, alert_node_id=alert_node_id)

    # Index the frame this call was given (not the notebook-level golden_by_id) so
    # labels and evidence ids always come from the same frame.
    df_by_id = gold_df.set_index("node_id", drop=False)
    m = evaluate_rca_episode(
        gold_df,
        df_by_id,
        ranked_v,
        cfgv,
        use_or_duplicates=use_or_duplicates,
    )

    sel_wo_alert = [n for n in selected_v if n != alert_node_id]
    row = {
        "model": mode_name,
        "rca_max_nodes": int(cfgv.rca_max_nodes),
        "rca_max_hops": ("None" if cfgv.rca_max_hops is None else int(cfgv.rca_max_hops)),
        "selected_nodes": len(sel_wo_alert),
    }
    row.update(m)

    # Hop diagnostics
    if hop_of is not None:
        hops_vals = [hop_of[n] for n in sel_wo_alert if hop_of.get(n) is not None]
        if hops_vals:
            row["max_hop_selected"] = int(np.max(hops_vals))
            row["p90_hop_selected"] = float(np.quantile(hops_vals, 0.90))
            row["mean_hop_selected"] = float(np.mean(hops_vals))
        else:
            row["max_hop_selected"] = 0
            row["p90_hop_selected"] = 0.0
            row["mean_hop_selected"] = 0.0

    return row


def _sort_ablation(df: pd.DataFrame) -> pd.DataFrame:
    sort_cols = []
    for c in ["core_f1", "ext_f1", "core_P@10", "ext_P@10"]:
        if c in df.columns:
            sort_cols.append(c)
    for c in ["S_items", "returned_items_total_ranked", "S_nodes", "selected_nodes"]:
        if c in df.columns:
            sort_cols.append(c)

    ascending = []
    for c in sort_cols:
        if c in ("S_items", "returned_items_total_ranked", "S_nodes", "selected_nodes"):
            ascending.append(True)
        else:
            ascending.append(False)

    if sort_cols:
        return df.sort_values(sort_cols, ascending=ascending)
    return df


# Hop sensitivity 
def run_hop_sensitivity(
    G,
    gold_df,
    cfg_base,
    alert_node_id=ALERT_NODE_ID,
    fixed_max_nodes=5000,
    max_hops_grid=(1, 2, 3, 5, 10, 25, 50, 100, 150, None),
    mode_name="contrastive",
    compute_hop_diagnostics=True,
):
    """Sweep the walk's hop limit at a fixed node budget.

    Returns one ``_run_walk_eval`` row per entry of ``max_hops_grid`` (tagged
    ``ablation="hop_sensitivity"``), ordered by hop limit with unlimited (None) last.
    """
    def _evaluate_hops(max_hops):
        variant = replace(cfg_base, rca_max_nodes=int(fixed_max_nodes), rca_max_hops=max_hops)
        result = _run_walk_eval(
            G=G,
            gold_df=gold_df,
            cfgv=variant,
            alert_node_id=alert_node_id,
            mode_name=mode_name,
            compute_hop_diagnostics=compute_hop_diagnostics,
        )
        result["ablation"] = "hop_sensitivity"
        return result

    def _unlimited_last(col):
        # _run_walk_eval stores an unlimited hop count as the string "None".
        return col.map(lambda v: 10**9 if v == "None" else int(v))

    table = pd.DataFrame([_evaluate_hops(h) for h in max_hops_grid])
    if "rca_max_hops" not in table.columns:
        return table
    return table.sort_values("rca_max_hops", key=_unlimited_last)


# Node budget sensitivity 
def run_budget_sensitivity(
    G,
    gold_df,
    cfg_base,
    alert_node_id=ALERT_NODE_ID,
    max_nodes_grid=(25, 50, 100, 150, 200),
    fixed_max_hops=None,
    mode_name="contrastive",
    compute_hop_diagnostics=True,
):
    """Sweep the walk's node budget at a fixed hop limit.

    Returns one ``_run_walk_eval`` row per budget in ``max_nodes_grid`` (tagged
    ``ablation="budget_sensitivity"``), ordered by ``rca_max_nodes``.
    """
    def _evaluate_budget(max_nodes):
        variant = replace(cfg_base, rca_max_nodes=int(max_nodes), rca_max_hops=fixed_max_hops)
        result = _run_walk_eval(
            G=G,
            gold_df=gold_df,
            cfgv=variant,
            alert_node_id=alert_node_id,
            mode_name=mode_name,
            compute_hop_diagnostics=compute_hop_diagnostics,
        )
        result["ablation"] = "budget_sensitivity"
        return result

    table = pd.DataFrame([_evaluate_budget(b) for b in max_nodes_grid])
    if "rca_max_nodes" not in table.columns:
        return table
    return table.sort_values("rca_max_nodes")


# Experiment C: hop-limit sweep (node budget fixed at 5000) and node-budget sweep
# (no hop limit), both on the notebook-level graph `G`.
# NOTE(review): `G` is whatever the notebook-level name holds at this point — make sure
# no earlier cell (e.g. a score-blending loop) has rebound it to a different graph.
hop_sens_contrastive = run_hop_sensitivity(
    G=G,
    gold_df=gold_df,
    cfg_base=cfg,
    alert_node_id=ALERT_NODE_ID,
    fixed_max_nodes=5000,
    max_hops_grid=(1, 2, 3, 5, 10, 25, 50, 100, 150, None),
    mode_name="contrastive",
    compute_hop_diagnostics=True,
)

budget_sens_contrastive = run_budget_sensitivity(
    G=G,
    gold_df=gold_df,
    cfg_base=cfg,
    alert_node_id=ALERT_NODE_ID,
    max_nodes_grid=(25, 50, 100, 150, 200),
    fixed_max_hops=None,   # no hop limit
    mode_name="contrastive",
    compute_hop_diagnostics=True,
)

# Show both tables sorted by quality metrics (desc), then by selection size (asc).
display(_sort_ablation(hop_sens_contrastive))
display(_sort_ablation(budget_sens_contrastive))

# EXPERIMENT D: SVD DIMENSIONS

In [None]:
def run_contrastive_once(
    svd_dim: int,
    normalize_svd: bool = True,
    normalize_proj: bool = True,
) -> dict:
    """Train one contrastive projector end to end and evaluate the resulting RCA walk.

    Pipeline: TF-IDF (fit on ``train_df``) -> TruncatedSVD(``svd_dim``) -> optional
    row normalisation -> 2-layer MLP projector trained with a symmetric InfoNCE loss
    on the pairs in ``pair_idx`` -> IsolationForest fit on the train embeddings ->
    anomaly scores for ``gold_df`` -> episode graph + alert node -> ``rca_walk_score``
    -> ``evaluate_rca_episode`` (OR-duplicates mode, Hit@k for core).

    Relies on notebook globals: ``cfg``, ``train_df``, ``gold_df``, ``pair_idx``,
    ``PROJECT_DIR``, ``ALERT_NODE_ID`` and the helpers ``seed_everything``,
    ``add_normalized_score``, ``build_episode_graph_from_scores``,
    ``attach_alert_node``. The walk and evaluation use the global ``cfg`` unchanged.

    NOTE: statement order matters for reproducibility — projector weight init and
    DataLoader shuffling draw from torch's RNG after the reseed below.

    Args:
        svd_dim: number of TruncatedSVD components (also the projector input width).
        normalize_svd: apply ``Normalizer`` (row-wise) to the SVD outputs.
        normalize_proj: L2-normalise the projector outputs.

    Returns:
        The metrics dict from ``evaluate_rca_episode`` plus ``svd_dim``,
        ``normalize_svd`` and ``normalize_proj``.
    """

    # Reseed at the start of every call so runs with different settings start from the
    # same RNG state (seed_everything is defined elsewhere — presumably seeds python/numpy/torch).
    seed_everything(cfg.seed)

    # TF-IDF vocabulary is learned on the training logs only and reused for gold_df.
    tfidf = TfidfVectorizer(
        max_features=cfg.tfidf_max_features,
        min_df=cfg.tfidf_min_df,
        ngram_range=cfg.tfidf_ngram_range,
    )
    X_train_tfidf = tfidf.fit_transform(train_df[cfg.text_col])
    X_gold_tfidf  = tfidf.transform(gold_df[cfg.text_col])

    # Dense low-rank representation; svd_dim must be smaller than the TF-IDF vocabulary size.
    svd = TruncatedSVD(n_components=svd_dim, random_state=cfg.seed)
    Z_train = svd.fit_transform(X_train_tfidf)
    Z_gold  = svd.transform(X_gold_tfidf)

    if normalize_svd:
        # copy=False: normalise in place; Z_train / Z_gold are local to this call.
        norm = Normalizer(copy=False)
        Z_train = norm.fit_transform(Z_train)
        Z_gold  = norm.transform(Z_gold)

    # Yields the two SVD rows of each (i, j) index pair; pair_idx presumably holds
    # positive pairs built earlier in the notebook — TODO confirm.
    class PairDataset(Dataset):
        def __init__(self, Z: np.ndarray, pairs: list[tuple[int,int]]):
            self.Z = torch.tensor(Z, dtype=torch.float32)
            self.pairs = pairs
        def __len__(self): return len(self.pairs)
        def __getitem__(self, idx):
            i, j = self.pairs[idx]
            return self.Z[i], self.Z[j]

    ds_pairs = PairDataset(Z_train, pair_idx)
    # drop_last=True keeps every batch at batch_size (in-batch negatives); with fewer
    # pairs than batch_size there are no batches and the projector stays untrained.
    dl = DataLoader(ds_pairs, batch_size=cfg.batch_size, shuffle=True, drop_last=True)

    # in_dim -> in_dim -> out_dim MLP, optionally L2-normalising its output.
    class Projector(nn.Module):
        def __init__(self, in_dim: int, out_dim: int, normalize_out: bool = True):
            super().__init__()
            self.normalize_out = normalize_out
            self.net = nn.Sequential(
                nn.Linear(in_dim, in_dim),
                nn.ReLU(),
                nn.Linear(in_dim, out_dim),
            )
        def forward(self, x):
            z = self.net(x)
            return F.normalize(z, dim=1) if self.normalize_out else z

    # Symmetric InfoNCE: row b of z1 should match row b of z2 (diagonal labels);
    # every other row in the batch acts as a negative.
    def info_nce_loss(z1, z2, temperature: float):
        B = z1.size(0)
        labels = torch.arange(B, device=z1.device)
        logits12 = (z1 @ z2.T) / temperature
        logits21 = (z2 @ z1.T) / temperature
        return 0.5 * (F.cross_entropy(logits12, labels) + F.cross_entropy(logits21, labels))

    proj = Projector(svd_dim, cfg.proj_dim, normalize_out=normalize_proj).to(cfg.device)
    opt = torch.optim.AdamW(proj.parameters(), lr=cfg.lr, weight_decay=1e-4)

    proj.train()
    for ep in range(cfg.epochs):
        # Per-batch losses are collected but not reported or returned.
        losses = []
        for x1, x2 in dl:
            x1, x2 = x1.to(cfg.device), x2.to(cfg.device)
            z1, z2 = proj(x1), proj(x2)
            loss = info_nce_loss(z1, z2, cfg.temperature)

            opt.zero_grad()
            loss.backward()
            # Clip the global gradient norm to 1.0 for stability.
            torch.nn.utils.clip_grad_norm_(proj.parameters(), 1.0)
            opt.step()

            losses.append(loss.item())

    # Embed all train and gold rows in one forward pass each (no batching).
    proj.eval()
    with torch.no_grad():
        Zt = torch.tensor(Z_train, dtype=torch.float32, device=cfg.device)
        Zg = torch.tensor(Z_gold, dtype=torch.float32, device=cfg.device)
        E_train = proj(Zt).cpu().numpy()
        E_gold  = proj(Zg).cpu().numpy()

    # Anomaly model in the learned embedding space, fit on training embeddings only.
    iso_cl = IsolationForest(
        n_estimators=100,
        contamination="auto",
        random_state=cfg.seed,
        n_jobs=-1,
    )
    iso_cl.fit(E_train)
    raw = iso_cl.decision_function(E_gold)   # higher = more normal
    score_cl = -raw                          # higher = more anomalous

    # Work on a copy so the notebook-level gold_df is not modified.
    gold_local = gold_df.copy()
    gold_local["score_cl"] = score_cl

    # add_normalized_score is defined elsewhere (presumably rescales score_cl into score_norm_cl).
    gold_local = add_normalized_score(gold_local, col="score_cl", out_col="score_norm_cl")

    # Blend contrastive and baseline normalised scores with weight cfg.score_merge_w.
    gold_local["score_norm_merged"] = (
        cfg.score_merge_w * gold_local["score_norm_cl"] +
        (1.0 - cfg.score_merge_w) * gold_local["score_norm"]
    )

    G_local = build_episode_graph_from_scores(gold_local, cfg)

    # Re-read the alert from disk (first row of the alert CSV) on every call.
    alert_path = PROJECT_DIR / cfg.alert_csv_path
    alert_df_local = pd.read_csv(alert_path)
    alert_row = alert_df_local.iloc[0]

    alert_ts = pd.to_datetime(alert_row["@timestamp"], utc=True)
    alert_src_ip = alert_row.get("source.ip", None)
    alert_name   = str(alert_row.get("kibana.alert.rule.name", "") or "")
    alert_reason = str(alert_row.get("kibana.alert.reason", "") or "")
    alert_desc   = str(alert_row.get("description", "") or "")

    _ = attach_alert_node(G_local, gold_local, alert_ts, alert_src_ip, alert_name, alert_reason, alert_desc)

    subG, selected_nodes, rca_ranked = rca_walk_score(G_local, cfg, alert_node_id=ALERT_NODE_ID)

    # Evaluate against gold_local (labels + evidence ids of the scored copy).
    df_by_id = gold_local.set_index("node_id", drop=False)
    metrics_or = evaluate_rca_episode(
        gold_local, df_by_id, rca_ranked, cfg,
        use_or_duplicates=True,
        compute_hit_for="core"  
    )

    # Tag the metrics with the settings of this run.
    metrics_or["svd_dim"] = svd_dim
    metrics_or["normalize_svd"] = bool(normalize_svd)
    metrics_or["normalize_proj"] = bool(normalize_proj)

    return metrics_or

In [None]:
# Experiment D: SVD width sweep with SVD normalisation on; projector normalisation from cfg.
svd_dim_list = getattr(cfg, "svd_dim_list", [64, 128, 256, 512])

rows = []
t0 = time.time()

for d in svd_dim_list:
    print(f"\n=== Running svd_dim={d} (normalize_svd=True, normalize_proj={cfg.normalize_proj}) ===")
    m = run_contrastive_once(svd_dim=d, normalize_svd=True, normalize_proj=cfg.normalize_proj)

    # Always-present metrics are indexed directly; @10 metrics depend on cfg.ks (None if absent).
    row = {"svd_dim": d}
    row.update({key: m[key] for key in ("normalize_svd", "normalize_proj", "S_items", "core_f1", "ext_f1")})
    row.update({key: m.get(key) for key in ("core_P@10", "ext_P@10", "core_Hit@10")})
    rows.append(row)

df_dim = pd.DataFrame(rows).sort_values("svd_dim").reset_index(drop=True)
print(f"\nDone. Total time: {(time.time()-t0)/60:.1f} min")
df_dim

# EXPERIMENT E: Post-SVD Normalization

In [None]:
# Experiment E: post-SVD normalisation on vs off at svd_dim=256.
rows = []
t0 = time.time()

for norm_flag in (True, False):
    print(f"\n=== Running normalize_svd={norm_flag} (svd_dim=256, normalize_proj={cfg.normalize_proj}) ===")
    m = run_contrastive_once(svd_dim=256, normalize_svd=norm_flag, normalize_proj=cfg.normalize_proj)

    # Always-present metrics are indexed directly; @10 metrics depend on cfg.ks (None if absent).
    row = {key: m[key] for key in ("svd_dim", "normalize_svd", "normalize_proj", "S_items", "core_f1", "ext_f1")}
    row.update({key: m.get(key) for key in ("core_P@10", "ext_P@10", "core_Hit@10")})
    rows.append(row)

df_norm = pd.DataFrame(rows).sort_values(["svd_dim", "normalize_svd"], ascending=[True, False]).reset_index(drop=True)
print(f"\nDone. Total time: {(time.time()-t0)/60:.1f} min")
df_norm

# EXPERIMENT F: Post-Projection Head Normalization

In [None]:
# Experiment F: 2x2 grid over post-SVD and post-projection normalisation at svd_dim=256.
rows = []
for norm_svd in (True, False):
    for norm_proj in (True, False):
        print(f"\n=== svd_norm={norm_svd}, proj_norm={norm_proj} (svd_dim=256) ===")
        m = run_contrastive_once(svd_dim=256, normalize_svd=norm_svd, normalize_proj=norm_proj)

        # Always-present metrics are indexed directly; @10 metrics depend on cfg.ks (None if absent).
        row = {key: m[key] for key in ("normalize_svd", "normalize_proj", "S_items", "core_f1", "ext_f1")}
        row.update({key: m.get(key) for key in ("core_P@10", "ext_P@10", "core_Hit@10")})
        rows.append(row)

df_norm_2x2 = pd.DataFrame(rows).sort_values(["normalize_svd", "normalize_proj"], ascending=[False, False]).reset_index(drop=True)
df_norm_2x2

# EXTRA METRIC: AUPRC

In [None]:
from sklearn.metrics import average_precision_score

def auprc_over_reachable_evidence(df, reachable_nodes, score_col, gt_col):
    """Average precision of ``score_col`` against ``gt_col`` over reachable evidence items.

    Each row of ``df`` is mapped to an evidence id via ``_evidence_series_with_fallback``.
    Only evidence ids returned by ``node_ids_to_evidence_set(df, reachable_nodes)`` are
    kept. Rows sharing an evidence id are collapsed with max(score) and max(label), so an
    item counts as positive if any of its rows is positive.

    Args:
        df: Row-level frame holding ``score_col`` and ``gt_col`` (plus whatever columns
            the evidence helpers read). It is not modified.
        reachable_nodes: Node ids treated as reachable; passed to
            ``node_ids_to_evidence_set``.
        score_col: Ranking-score column (higher scores are ranked first by AP).
        gt_col: Binary / boolean ground-truth column.

    Returns:
        ``(ap, pool_size, n_positive)``. ``ap`` is ``None`` when the pool contains a
        single class (all negatives, all positives, or empty), because AP is undefined there.
    """
    # Use the frame the caller passed in; reading a notebook global here would score
    # every caller (baseline and contrastive alike) on the same frame.
    frame = df.copy()
    frame["evidence_id"] = _evidence_series_with_fallback(frame)

    reachable_ev = node_ids_to_evidence_set(frame, reachable_nodes)

    # Restrict to reachable evidence items, keeping only the columns we aggregate.
    frame = frame.loc[frame["evidence_id"].isin(reachable_ev), ["evidence_id", score_col, gt_col]]

    # Collapse duplicate rows to one row per evidence item.
    ev = frame.groupby("evidence_id").agg(
        y_score=(score_col, "max"),
        y_true=(gt_col, "max"),
    )

    y_true = ev["y_true"].astype(int).to_numpy()
    y_score = ev["y_score"].astype(float).to_numpy()
    n_pos = int(y_true.sum())

    if n_pos == 0 or n_pos == len(y_true):
        return None, len(ev), n_pos

    return average_precision_score(y_true, y_score), len(ev), n_pos

In [None]:
# Evidence-level AUPRC of baseline vs contrastive scores on the same reachable node
# set, for both ground-truth definitions.
# NOTE(review): baseline scores are read from golden_df and contrastive scores from
# gold_df; compare the printed pool sizes to confirm both frames cover the same evidence.

def _fmt_ap(ap):
    """Format an AP value for printing; 'n/a' when it is None (AP not computed)."""
    return "n/a" if ap is None else f"{ap:.4f}"

# Baseline
ap_b_core, pool, pos = auprc_over_reachable_evidence(
    df=golden_df,
    reachable_nodes=reachable_nodes,
    score_col="score_norm",
    gt_col="gt_core",
)
ap_b_ext, _, _ = auprc_over_reachable_evidence(
    df=golden_df,
    reachable_nodes=reachable_nodes,
    score_col="score_norm",
    gt_col="gt_extended",
)

# Contrastive
ap_c_core, pool_c, pos_c = auprc_over_reachable_evidence(
    df=gold_df,
    reachable_nodes=reachable_nodes,
    score_col="score_norm_cl",
    gt_col="gt_core",
)
ap_c_ext, _, _ = auprc_over_reachable_evidence(
    df=gold_df,
    reachable_nodes=reachable_nodes,
    score_col="score_norm_cl",
    gt_col="gt_extended",
)

# AP can be None (single-class pool); formatting None with :.4f would raise TypeError.
print(f"Baseline AUPRC core/ext:     {_fmt_ap(ap_b_core)} / {_fmt_ap(ap_b_ext)}   pool={pool} pos={pos}")
print(f"Contrastive AUPRC core/ext:  {_fmt_ap(ap_c_core)} / {_fmt_ap(ap_c_ext)}   pool={pool_c} pos={pos_c}")


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_curve, average_precision_score

def pr_curve_over_reachable_evidence(
    golden_df,
    reachable_node_ids,
    score_col: str,
    gt_col: str,
):
    """Precision-recall curve of ``score_col`` vs ``gt_col`` over reachable evidence items.

    Each row is mapped to an evidence key via ``_evidence_series_with_fallback``. Rows
    whose ``node_id`` is in ``reachable_node_ids`` are kept. Rows sharing an evidence
    key are then collapsed with max(label) / max(score), in first-appearance order.

    NOTE(review): this filters rows by ``node_id``, whereas ``auprc_over_reachable_evidence``
    keeps every row of any evidence id returned by ``node_ids_to_evidence_set``. The two
    may therefore aggregate different rows per evidence item; confirm which pool is intended.

    Args:
        golden_df: Row-level frame with ``node_id``, ``score_col`` and ``gt_col``. It is not modified.
        reachable_node_ids: Node ids treated as reachable.
        score_col: Ranking-score column.
        gt_col: Binary / boolean ground-truth column.

    Returns:
        ``(precision, recall, ap, prevalence, pool_size, n_positive)``, where
        ``prevalence`` is the positive fraction (the chance-level precision).

    Raises:
        ValueError: if no rows remain after filtering to reachable nodes.
    """
    # Use the frame the caller passed in rather than a notebook global.
    df = golden_df.copy()
    df["evidence_key"] = _evidence_series_with_fallback(df)

    df = df[df["node_id"].isin(reachable_node_ids)]
    if df.empty:
        raise ValueError(f"No reachable rows to build a PR curve for {score_col!r} vs {gt_col!r}")

    grp = df.groupby("evidence_key", sort=False)
    y_true = grp[gt_col].max().astype(int).to_numpy()
    y_score = grp[score_col].max().astype(float).to_numpy()

    precision, recall, _ = precision_recall_curve(y_true, y_score)
    ap = average_precision_score(y_true, y_score)
    prevalence = float(y_true.mean())
    return precision, recall, ap, prevalence, len(y_true), int(y_true.sum())

def plot_pr(ax, title, pr_b, pr_c, ylim=(0.0, 0.10)):
    """Draw baseline vs contrastive precision-recall curves on a single axis.

    Only ``ax`` is modified; the function no longer reaches into a global ``axes``.

    Args:
        ax: Matplotlib axis to draw on.
        title: Axis title.
        pr_b: Baseline 6-tuple ``(precision, recall, ap, prevalence, pool, n_pos)``,
            as returned by ``pr_curve_over_reachable_evidence``.
        pr_c: Contrastive 6-tuple in the same format.
        ylim: Precision-axis limits. The default keeps the low-precision zoom these
            figures were rendered with; pass ``(0, 1)`` for the full range.

    Returns:
        The same ``ax``, for chaining.
    """
    p_b, r_b, ap_b, prev_b, _, _ = pr_b
    p_c, r_c, ap_c, _, _, _ = pr_c

    ax.plot(r_b, p_b, label=f"Baseline (AP={ap_b:.4f})")
    ax.plot(r_c, p_c, label=f"Contrastive (AP={ap_c:.4f})")
    # Chance-level precision equals the positive prevalence; the baseline pool's value is used.
    ax.hlines(prev_b, 0, 1, linestyles="--", label=f"Random (p={prev_b:.4f})")

    ax.set_title(title)
    ax.set_xlabel("Recall")
    ax.set_ylabel("Precision")
    ax.set_xlim(0, 1)
    ax.set_ylim(*ylim)
    ax.legend()
    return ax

# Frames and score columns passed for each model: baseline first, contrastive second.
score_sources = ((golden_df, "score_norm"), (gold_df, "score_norm_cl"))

pr_b_core, pr_c_core = (
    pr_curve_over_reachable_evidence(frame, reachable_nodes, col, "gt_core")
    for frame, col in score_sources
)
pr_b_ext, pr_c_ext = (
    pr_curve_over_reachable_evidence(frame, reachable_nodes, col, "gt_extended")
    for frame, col in score_sources
)

fig, axes = plt.subplots(1, 2, figsize=(12, 4))
panels = (
    ("PR curve over reachable evidence (core GT)", pr_b_core, pr_c_core),
    ("PR curve over reachable evidence (extended GT)", pr_b_ext, pr_c_ext),
)
for panel_ax, (panel_title, b_curve, c_curve) in zip(axes, panels):
    plot_pr(panel_ax, panel_title, b_curve, c_curve)
plt.tight_layout()
plt.show()