In [1]:
# make_embeddings_from_tokenized.py
# Produces embeddings.npy (row-aligned) from tokenized_logs.pt
# Requires: torch, transformers, numpy

import os
import torch
import numpy as np
from transformers import AutoModel
from torch.utils.data import DataLoader, TensorDataset

TOKENIZED_PT = "tokenized_logs.pt"   # your file
EMBED_OUT = "embeddings.npy"
MODEL_NAME = "distilroberta-base"
BATCH_SIZE = 64    # lower this if you run out of GPU memory
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
PRINT_EVERY = 10

def main():
    assert os.path.exists(TOKENIZED_PT), f"tokenized file not found: {TOKENIZED_PT}"
    print(f"Loading tokenized tensors from {TOKENIZED_PT} ...")
    data = torch.load(TOKENIZED_PT, weights_only=False)

    # Expect keys input_ids, attention_mask (and optionally labels)
    if "input_ids" not in data or "attention_mask" not in data:
        raise KeyError("tokenized file must contain 'input_ids' and 'attention_mask' tensors")

    input_ids = data["input_ids"]
    attention_mask = data["attention_mask"]

    # convert to torch tensors on cpu (DataLoader will move to device)
    if not isinstance(input_ids, torch.Tensor):
        input_ids = torch.tensor(input_ids)
    if not isinstance(attention_mask, torch.Tensor):
        attention_mask = torch.tensor(attention_mask)

    n_rows = input_ids.size(0)
    print(f"Found {n_rows} rows, seq_len={input_ids.size(1)}")

    # DataLoader
    ds = TensorDataset(input_ids, attention_mask)
    loader = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=False)

    # Load model
    print(f"Loading encoder model '{MODEL_NAME}' to device={DEVICE} ...")
    model = AutoModel.from_pretrained(MODEL_NAME)
    model.to(DEVICE)
    model.eval()

    # We'll collect embeddings in a list then stack
    embeddings = []
    with torch.no_grad():
        for i, batch in enumerate(loader):
            batch_input_ids = batch[0].to(DEVICE)
            batch_attention_mask = batch[1].to(DEVICE)

            outputs = model(input_ids=batch_input_ids, attention_mask=batch_attention_mask)
            # last_hidden_state shape: (batch, seq_len, hidden)
            last_hidden = outputs.last_hidden_state
            # CLS token (DistilRoBERTa uses first token) -> index 0
            cls_emb = last_hidden[:, 0, :].detach().cpu().numpy()  # shape (batch, hidden)
            embeddings.append(cls_emb)

            if (i + 1) % PRINT_EVERY == 0:
                print(f"Processed {min((i+1)*BATCH_SIZE, n_rows)}/{n_rows} rows")

    embeddings = np.vstack(embeddings)  # shape (n_rows, hidden)
    assert embeddings.shape[0] == n_rows, f"Mismatch rows: {embeddings.shape[0]} vs {n_rows}"
    print(f"Embeddings shape: {embeddings.shape}")

    # Optionally L2-normalize rows (your PPO script normalizes again; do if you prefer)
    # norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    # norms[norms == 0] = 1.0
    # embeddings = embeddings / norms

    # Save .npy
    np.save(EMBED_OUT, embeddings)
    print(f"Saved embeddings to: {EMBED_OUT}")

if __name__ == "__main__":
    main()


Loading tokenized tensors from tokenized_logs.pt ...
Found 7000 rows, seq_len=128
Loading encoder model 'distilroberta-base' to device=cuda ...
Processed 640/7000 rows
Processed 1280/7000 rows
Processed 1920/7000 rows
Processed 2560/7000 rows
Processed 3200/7000 rows
Processed 3840/7000 rows
Processed 4480/7000 rows
Processed 5120/7000 rows
Processed 5760/7000 rows
Processed 6400/7000 rows
Processed 7000/7000 rows
Embeddings shape: (7000, 768)
Saved embeddings to: embeddings.npy


In [2]:
#!/usr/bin/env python3
"""
make_observations_22d.py

Produces observations_22d.npy = [PCA20 components (scaled)] + [recon_scaled] + [rule_scaled]
Inputs:
  - embeddings.npy               (shape N x D)
  - synthetic_nginx_logs.csv     (must contain query, uri; optional: rule_hit_count, rule_severity_max, label)
  - reconstruction_errors.csv    (must contain 'index' and 'reconstruction_error')

Outputs:
  - observations_22d.npy        (N x 22)
  - pca_model.npz               (contains PCA components & std scaler params) (optional)
"""

import os
import numpy as np
import pandas as pd
import re
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler, MinMaxScaler
import joblib

# -------- CONFIG --------
EMBED_NPY = "embeddings.npy"
LOG_CSV = "synthetic_nginx_logs.csv"
ERR_CSV = "reconstruction_errors.csv"

OUT_OBS = "observations_22d.npy"
OUT_PCA_PICKLE = "pca_and_scaler.pkl"   # saves PCA + StandardScaler objects for future inference
PCA_DIM = 30
WHITEN = False   # if True, PCA will whiten (not recommended usually)
SEED = 42

# Pattern groups for OWASP-like detection
SQLI_PATTERNS = [
    r"(?i)\b(or|and)\b\s+\d+=\d+",
    r"(?i)union\s+select",
    r"(?i)select\s+.+\s+from",
    r"(?i)drop\s+table",
    r"(?i)--\s*$",
    r"(?i)1=1",
]
XSS_PATTERNS = [r"(?i)<script.*?>", r"(?i)onerror=|onmouseover=|<img\s+src", r"(?i)alert\("]
LFI_PATTERNS = [r"\.\./", r"etc/passwd"]
CMD_PATTERNS = [r";\s*rm\s+-rf", r"\|\s*cat\s+/etc/passwd"]
COMPILED_PATTERNS = [(SQLI_PATTERNS, 2.0), (XSS_PATTERNS, 1.8), (LFI_PATTERNS, 1.5), (CMD_PATTERNS, 2.0)]
PATTERN_MAX_SCORE = 5.0

def pattern_score(text: str) -> float:
    if not isinstance(text, str) or text.strip() == "":
        return 0.0
    t = text.lower()
    score = 0.0
    for pats, w in COMPILED_PATTERNS:
        for p in pats:
            if re.search(p, t):
                score += w
                break
    return float(min(score, PATTERN_MAX_SCORE))

def l2_norm_rows(x: np.ndarray) -> np.ndarray:
    norms = np.linalg.norm(x, axis=1, keepdims=True)
    norms[norms == 0] = 1.0
    return x / norms

def main():
    # checks
    for f in (EMBED_NPY, LOG_CSV, ERR_CSV):
        if not os.path.exists(f):
            raise FileNotFoundError(f"Required file not found: {f}")

    # load
    print("[*] Loading embeddings...")
    embeddings = np.load(EMBED_NPY)    # shape (N, D)
    N, D = embeddings.shape
    print(f"    embeddings shape: {embeddings.shape}")

    print("[*] Loading logs CSV...")
    df = pd.read_csv(LOG_CSV, dtype=str).fillna("")
    if len(df) != N:
        # Try to proceed but warn
        print(f"Warning: CSV rows ({len(df)}) != embeddings rows ({N}). Will attempt to align by index and truncate/pad as needed.")
    # Ensure needed columns exist
    if "query" not in df.columns:
        df["query"] = ""
    if "uri" not in df.columns:
        df["uri"] = ""
    # numeric rule fields optional
    if "rule_hit_count" not in df.columns:
        df["rule_hit_count"] = "0"
    if "rule_severity_max" not in df.columns:
        df["rule_severity_max"] = "0"

    print("[*] Loading reconstruction errors CSV...")
    df_errs = pd.read_csv(ERR_CSV).fillna(0.0)
    if "index" not in df_errs.columns or "reconstruction_error" not in df_errs.columns:
        raise ValueError("reconstruction_errors.csv must contain 'index' and 'reconstruction_error' columns")

    # Align reconstruction_error to df rows by index if possible, otherwise use position
    recon_series = df_errs.set_index("index")["reconstruction_error"]
    if len(recon_series) >= len(df):
        df["reconstruction_error"] = recon_series.values[:len(df)]
    else:
        # join by index (if df has standard 0..N-1)
        try:
            df = df.join(recon_series, how="left")
            df["reconstruction_error"] = df["reconstruction_error"].fillna(0.0)
        except Exception:
            # last resort: pad with zeros or truncate
            recon_vals = recon_series.values
            if len(recon_vals) < len(df):
                recon_full = np.zeros(len(df), dtype=float)
                recon_full[:len(recon_vals)] = recon_vals
                df["reconstruction_error"] = recon_full
            else:
                df["reconstruction_error"] = recon_vals[:len(df)]

    # If df has more/less rows than embeddings, align to min length
    M = min(len(df), N)
    if M != N:
        print(f"Adjusting to min rows: {M} (min of embeddings and CSV)")
        embeddings = embeddings[:M]
        df = df.iloc[:M].reset_index(drop=True)
        N = M

    # 1) L2-normalize embeddings
    emb_norm = l2_norm_rows(embeddings.astype(np.float32))
    print("[*] Embeddings L2-normalized")

    # 2) PCA -> PCA_DIM
    print(f"[*] Running PCA -> {PCA_DIM} components ...")
    pca = PCA(n_components=PCA_DIM, whiten=WHITEN, random_state=SEED)
    pca_feats = pca.fit_transform(emb_norm)   # shape (N, PCA_DIM)
    print("    PCA done. Explained variance sum:", float(np.sum(pca.explained_variance_ratio_)))

    # 3) standardize PCA outputs (zero mean unit var) then min-max to [-1,1]
    scaler = StandardScaler()
    pca_std = scaler.fit_transform(pca_feats)  # zero mean, unit var
    # min-max to [-1,1]
    pmin = pca_std.min(axis=0)
    pmax = pca_std.max(axis=0)
    prange = np.where((pmax - pmin) == 0, 1.0, (pmax - pmin))
    pca_scaled = 2.0 * (pca_std - pmin) / prange - 1.0
    print("[*] PCA standardized and scaled to [-1,1]")

    # 4) pattern score (OWASP-like) on query + uri
    print("[*] Computing pattern scores from query+uri ...")
    combined_text = (df["query"].fillna("") + " " + df["uri"].fillna("")).astype(str)
    pattern_scores = combined_text.apply(pattern_score).values  # 0..PATTERN_MAX_SCORE
    pattern_scores_norm = np.clip(pattern_scores / PATTERN_MAX_SCORE, 0.0, 1.0)

    # 5) numeric rule features -> min-max [0,1]
    rule_hit = pd.to_numeric(df["rule_hit_count"], errors="coerce").fillna(0.0).astype(float).values
    rule_sev = pd.to_numeric(df["rule_severity_max"], errors="coerce").fillna(0.0).astype(float).values
    rule_numeric = np.vstack([rule_hit, rule_sev]).T
    rule_scaler = MinMaxScaler(feature_range=(0.0, 1.0))
    # safe fit: if constant columns, MinMaxScaler will produce zeros
    try:
        rule_numeric_scaled = rule_scaler.fit_transform(rule_numeric)
    except Exception:
        # fallback to manual normalization
        rule_numeric_scaled = np.zeros_like(rule_numeric)
        for i in range(rule_numeric.shape[1]):
            col = rule_numeric[:, i]
            mn, mx = col.min(), col.max()
            den = mx - mn if (mx - mn) > 1e-12 else 1.0
            rule_numeric_scaled[:, i] = (col - mn) / den

    rule_numeric_score = np.max(rule_numeric_scaled, axis=1)  # use max of the two normalized numeric signals

    # 6) combine numeric rule score and pattern score into rule_score [0,1]
    rule_score = np.clip(0.6 * rule_numeric_score + 0.4 * pattern_scores_norm, 0.0, 1.0)

    # 7) normalize reconstruction_error -> [0,1]
    recon = pd.to_numeric(df["reconstruction_error"], errors="coerce").fillna(0.0).astype(float).values
    rmin, rmax = float(recon.min()), float(recon.max())
    if (rmax - rmin) < 1e-12:
        recon_norm = np.zeros_like(recon)
    else:
        recon_norm = (recon - rmin) / (rmax - rmin)
    recon_norm = np.clip(recon_norm, 0.0, 1.0)

    # 8) scale recon_norm and rule_score to [-1,1] for NN stability (optional but matches previous pipeline)
    recon_scaled = 2.0 * recon_norm - 1.0
    rule_scaled = 2.0 * rule_score - 1.0

    # 9) concatenate to produce observations (N x 22)
    print("[*] Concatenating features into observations (PCA20 + recon + rule) ...")
    observations = np.zeros((N, PCA_DIM + 2), dtype=np.float32)
    observations[:, :PCA_DIM] = pca_scaled.astype(np.float32)
    observations[:, PCA_DIM] = recon_scaled.astype(np.float32)
    observations[:, PCA_DIM + 1] = rule_scaled.astype(np.float32)

    # 10) save observations.npy
    np.save(OUT_OBS, observations)
    print(f"[+] Saved observations to: {OUT_OBS}  (shape {observations.shape})")

    # 11) optionally persist PCA + scaler + rule_scaler + normalization params for inference
    try:
        joblib.dump({
            "pca": pca,
            "std_scaler": scaler,
            "pca_min": pmin,
            "pca_max": pmax,
            "rule_scaler": rule_scaler,
            "recon_min": rmin,
            "recon_max": rmax,
        }, OUT_PCA_PICKLE)
        print(f"[+] Saved PCA + scaler objects to: {OUT_PCA_PICKLE}")
    except Exception as e:
        print("Warning: unable to save PCA objects:", e)

    print("[*] Done.")

if __name__ == "__main__":
    main()


[*] Loading embeddings...
    embeddings shape: (7000, 768)
[*] Loading logs CSV...
[*] Loading reconstruction errors CSV...
[*] Embeddings L2-normalized
[*] Running PCA -> 30 components ...
    PCA done. Explained variance sum: 0.9218317270278931
[*] PCA standardized and scaled to [-1,1]
[*] Computing pattern scores from query+uri ...
[*] Concatenating features into observations (PCA20 + recon + rule) ...
[+] Saved observations to: observations_22d.npy  (shape (7000, 32))
[+] Saved PCA + scaler objects to: pca_and_scaler.pkl
[*] Done.


In [4]:
#!/usr/bin/env python3
"""
ppo_pipeline_complete.py

Single-file pipeline (train + eval + save/load) for custom PPO on 22-d WAF features (PCA20 + recon + rule).

Prereqs:
  pip install torch numpy pandas scikit-learn gym

Files expected in working dir:
  - observations_22d.npy        (shape [N,22])
  - synthetic_nginx_logs.csv    (must contain 'label' column)
Optional:
  - reconstruction_errors.csv (only needed if you want to re-create observations)

Outputs:
  - ppo_policy_final.pt
  - ppo_policy_best.pt
  - logs_with_policy_actions.csv
"""

import os
import time
import random
from typing import Tuple, Dict, Any

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
import gym
from gym import spaces

# ----------------------------
# CONFIG (tweak as needed)
# ----------------------------
OBS_NPY = "observations_22d.npy"
LOG_CSV = "synthetic_nginx_logs.csv"

OUT_MODEL_FINAL = "ppo_policy_final.pt"
OUT_MODEL_BEST = "ppo_policy_best.pt"
OUT_CSV = "logs_with_policy_actions.csv"

SEED = 42
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# training hyperparams
TOTAL_EPOCHS = 20
NUM_STEPS = 4096
PPO_EPOCHS = 6
MINI_BATCH_SIZE = 128
GAMMA = 0.995
GAE_LAMBDA = 0.95
CLIP_EPS = 0.2
LR = 3e-4
VALUE_COEF = 0.5
ENTROPY_COEF = 0.01
MAX_GRAD_NORM = 0.5

# env & reward shaping
EPISODE_LEN = 512
ATTACK_BIAS = 0.9
HIGH_THRESH = 0.55
MID_THRESH = 0.30

HIDDEN_SIZES = [128, 128]
VERBOSE = True

# reproducibility
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if DEVICE == "cuda":
    torch.cuda.manual_seed_all(SEED)


# ----------------------------
# Utilities / model helpers
# ----------------------------
class ActorCritic(nn.Module):
    def __init__(self, obs_dim: int, action_dim: int = 2, hidden: Tuple[int, ...] = (128, 128)):
        super().__init__()
        layers = []
        prev = obs_dim
        for h in hidden:
            layers.append(nn.Linear(prev, h))
            layers.append(nn.ReLU())
            prev = h
        self.shared = nn.Sequential(*layers)
        self.pi = nn.Linear(prev, action_dim)
        self.v = nn.Linear(prev, 1)

    def forward(self, x: torch.Tensor):
        h = self.shared(x)
        return self.pi(h), self.v(h).squeeze(-1)


class PPOAgent:
    def __init__(self, obs_dim: int, lr: float = LR, hidden=(128, 128)):
        self.net = ActorCritic(obs_dim, action_dim=2, hidden=hidden).to(DEVICE)
        self.opt = optim.Adam(self.net.parameters(), lr=lr)

    def act(self, obs_np: np.ndarray) -> Tuple[int, float, float]:
        t = torch.tensor(obs_np, dtype=torch.float32, device=DEVICE).unsqueeze(0)
        logits, val = self.net(t)
        probs = torch.softmax(logits, dim=-1)
        dist = Categorical(probs)
        a = dist.sample()
        return int(a.item()), float(dist.log_prob(a).item()), float(val.item())

    def get_logits_values(self, batch_obs: np.ndarray):
        t = torch.tensor(batch_obs, dtype=torch.float32, device=DEVICE)
        logits, vals = self.net(t)
        return logits, vals

    def update(self, obs_buf, act_buf, adv_buf, ret_buf, logp_buf,
               ppo_epochs=PPO_EPOCHS, mini_batch=MINI_BATCH_SIZE):
        obs = torch.tensor(obs_buf, dtype=torch.float32, device=DEVICE)
        acts = torch.tensor(act_buf, dtype=torch.long, device=DEVICE)
        old_logp = torch.tensor(logp_buf, dtype=torch.float32, device=DEVICE)
        advs = torch.tensor(adv_buf, dtype=torch.float32, device=DEVICE)
        rets = torch.tensor(ret_buf, dtype=torch.float32, device=DEVICE)

        advs = (advs - advs.mean()) / (advs.std() + 1e-9)
        N = obs.shape[0]
        idxs = np.arange(N)
        for _ in range(ppo_epochs):
            np.random.shuffle(idxs)
            for start in range(0, N, mini_batch):
                mb = idxs[start:start + mini_batch]
                mb_obs = obs[mb]
                mb_acts = acts[mb]
                mb_advs = advs[mb]
                mb_rets = rets[mb]
                mb_old_logp = old_logp[mb]

                logits, vals = self.net(mb_obs)
                probs = torch.softmax(logits, dim=-1)
                dist = Categorical(probs)
                mb_logp = dist.log_prob(mb_acts)
                entropy = dist.entropy().mean()

                ratio = torch.exp(mb_logp - mb_old_logp)
                surr1 = ratio * mb_advs
                surr2 = torch.clamp(ratio, 1.0 - CLIP_EPS, 1.0 + CLIP_EPS) * mb_advs
                policy_loss = -torch.min(surr1, surr2).mean()
                value_loss = ((vals - mb_rets) ** 2).mean()

                loss = policy_loss + VALUE_COEF * value_loss - ENTROPY_COEF * entropy

                self.opt.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.net.parameters(), MAX_GRAD_NORM)
                self.opt.step()

    def save(self, path: str):
        torch.save(self.net.state_dict(), path)

    def load(self, path: str):
        self.net.load_state_dict(torch.load(path, map_location=DEVICE))


# ----------------------------
# Environment (weighted sampling)
# ----------------------------
class WafPPOEnvRandom(gym.Env):
    def __init__(self, observations: np.ndarray, labels: np.ndarray,
                 episode_len: int = EPISODE_LEN, attack_bias: float = ATTACK_BIAS):
        super().__init__()
        self.observations = observations.astype(np.float32)
        self.labels = labels.astype(int)
        self.n = observations.shape[0]
        self.episode_len = int(episode_len)
        alpha = 1.0 + 50.0 * float(np.clip(attack_bias, 0.0, 1.0))
        weights = np.ones(self.n, dtype=np.float32)
        weights[self.labels == 1] *= alpha
        self.sampling_weights = weights / np.sum(weights)
        self.current_step = 0
        self.current_idx = 0
        self.observation_space = spaces.Box(low=-5.0, high=5.0, shape=(self.observations.shape[1],), dtype=np.float32)
        self.action_space = spaces.Discrete(2)
        self._seed = None
        self.centroid_cached = None

    def seed(self, seed=None):
        if seed is None:
            seed = np.random.randint(0, 2 ** 31 - 1)
        self._seed = int(seed)
        np.random.seed(self._seed)
        random.seed(self._seed)
        return [self._seed]

    def reset(self, *, seed=None, options=None):
        if seed is not None:
            self.seed(int(seed))
        self.current_step = 0
        self.current_idx = int(np.random.choice(self.n, p=self.sampling_weights))
        return self.observations[self.current_idx].copy(), {}

    def _ensure_centroid(self):
        # centroid for pseudo-embedding anomaly computed on-the-fly for speed
        if self.centroid_cached is None:
            benign_mask = (self.labels == 0)
            if benign_mask.sum() == 0:
                self.centroid_cached = np.zeros(self.observations.shape[1] - 2, dtype=np.float32)
            else:
                self.centroid_cached = self.observations[benign_mask, :-2].mean(axis=0)
                norm = np.linalg.norm(self.centroid_cached)
                if norm > 0:
                    self.centroid_cached = self.centroid_cached / norm

    def step(self, action):
        obs = self.observations[self.current_idx]
        recon_scaled = float(obs[-2]); rule_scaled = float(obs[-1])
        recon_val = (recon_scaled + 1.0) / 2.0
        rule_val = (rule_scaled + 1.0) / 2.0

        self._ensure_centroid()
        emb_vec = obs[:-2].astype(np.float32)
        denom = (np.linalg.norm(emb_vec) + 1e-9)
        cos = float(np.dot(emb_vec, self.centroid_cached) / denom)
        emb_anom = float((1.0 - cos) / 2.0)

        combined_anom = 0.6 * recon_val + 0.3 * rule_val + 0.1 * emb_anom
        combined_anom = float(np.clip(combined_anom, 0.0, 1.0))

        # stronger shaped reward
        if combined_anom >= HIGH_THRESH:
            reward = 3.0 if action == 1 else -2.0
        elif combined_anom >= MID_THRESH:
            reward = 1.0 if action == 1 else -0.8
        else:
            reward = 1.0 if action == 0 else -0.9

        self.current_idx = int(np.random.choice(self.n, p=self.sampling_weights))
        self.current_step += 1
        done = self.current_step >= self.episode_len
        next_obs = self.observations[self.current_idx].copy() if not done else np.zeros_like(self.observations[0], dtype=np.float32)
        return next_obs, float(reward), done, {}

    def render(self, mode="human"):
        pass

    def close(self):
        pass


# ----------------------------
# Rollout collection & GAE
# ----------------------------
def collect_rollout(agent: PPOAgent, env: WafPPOEnvRandom, num_steps: int):
    obs_buf, act_buf, rew_buf, val_buf, logp_buf, done_buf = [], [], [], [], [], []
    obs, _ = env.reset()
    for _ in range(num_steps):
        a, logp, val = agent.act(obs)
        next_obs, rew, done, _ = env.step(a)
        obs_buf.append(obs.copy()); act_buf.append(a); rew_buf.append(rew)
        val_buf.append(val); logp_buf.append(logp); done_buf.append(done)
        obs = next_obs
        if done:
            obs, _ = env.reset()
    return (np.array(obs_buf), np.array(act_buf), np.array(rew_buf), np.array(val_buf), np.array(logp_buf), np.array(done_buf))


def compute_gae(rewards, values, dones, last_val, gamma=GAMMA, lam=GAE_LAMBDA):
    T = len(rewards)
    adv = np.zeros(T, dtype=np.float32)
    last_gae = 0.0
    for t in reversed(range(T)):
        if t == T - 1:
            next_non_term = 0.0 if dones[t] else 1.0
            next_val = last_val
        else:
            next_non_term = 0.0 if dones[t + 1] else 1.0
            next_val = values[t + 1]
        delta = rewards[t] + gamma * next_val * next_non_term - values[t]
        last_gae = delta + gamma * lam * next_non_term * last_gae
        adv[t] = last_gae
    returns = adv + values
    return adv, returns


# ----------------------------
# Offline evaluation
# ----------------------------
def evaluate_policy(agent: PPOAgent, observations: np.ndarray, labels: np.ndarray) -> Dict[str, Any]:
    actions = []
    agent.net.eval()
    with torch.no_grad():
        for i in range(len(observations)):
            x = torch.tensor(observations[i], dtype=torch.float32, device=DEVICE).unsqueeze(0)
            logits, _ = agent.net(x)
            probs = torch.softmax(logits, dim=-1)
            action = int(torch.argmax(probs, dim=-1).cpu().numpy()[0])
            actions.append(action)
    actions = np.array(actions, dtype=int)
    acc = accuracy_score(labels, actions)
    p, r, f, _ = precision_recall_fscore_support(labels, actions, average=None, zero_division=0)
    attack_prec = float(p[1]) if len(p) > 1 else 0.0
    attack_rec = float(r[1]) if len(r) > 1 else 0.0
    attack_f1 = float(f[1]) if len(f) > 1 else 0.0
    cm = confusion_matrix(labels, actions)
    return {"accuracy": acc, "attack_prec": attack_prec, "attack_rec": attack_rec, "attack_f1": attack_f1, "confusion_matrix": cm, "actions": actions}


# ----------------------------
# Training loop (main)
# ----------------------------
def train(observations: np.ndarray, labels: np.ndarray):
    N, obs_dim = observations.shape
    env = WafPPOEnvRandom(observations, labels, episode_len=EPISODE_LEN, attack_bias=ATTACK_BIAS)
    agent = PPOAgent(obs_dim=obs_dim, lr=LR, hidden=tuple(HIDDEN_SIZES))

    best_f1 = -1.0
    best_path = OUT_MODEL_BEST

    print("[*] Starting training: epochs=", TOTAL_EPOCHS, "NUM_STEPS=", NUM_STEPS)
    start_time = time.time()

    for epoch in range(1, TOTAL_EPOCHS + 1):
        obs_b, act_b, rew_b, val_b, logp_b, done_b = collect_rollout(agent, env, NUM_STEPS)
        with torch.no_grad():
            _, last_val = agent.get_logits_values(obs_b[-1:].astype(np.float32))
            last_val = float(last_val.cpu().numpy()[0])
        adv_b, ret_b = compute_gae(rew_b, val_b, done_b, last_val)
        agent.update(obs_b, act_b, adv_b, ret_b, logp_b)

        avg_rew = float(np.mean(rew_b))
        action_counts = np.bincount(act_b, minlength=2)
        if VERBOSE:
            print(f"[Epoch {epoch}/{TOTAL_EPOCHS}] avg_rew={avg_rew:.4f} actions=ALLOW:{action_counts[0]} BLOCK:{action_counts[1]}")

        # offline evaluation and model saving by attack_f1
        metrics = evaluate_policy(agent, observations, labels)
        if VERBOSE:
            print(f"  Offline: acc={metrics['accuracy']:.4f} attack_f1={metrics['attack_f1']:.4f} attack_rec={metrics['attack_rec']:.4f} attack_prec={metrics['attack_prec']:.4f}")

        if metrics["attack_f1"] > best_f1:
            best_f1 = metrics["attack_f1"]
            agent.save(best_path)
            if VERBOSE:
                print(f"  [+] New best model saved (attack_f1={best_f1:.4f})")

    elapsed = (time.time() - start_time) / 60.0
    print(f"[*] Training finished in {elapsed:.2f} minutes. Best attack_f1={best_f1:.4f}")
    agent.save(OUT_MODEL_FINAL)
    print(f"[*] Final model saved to {OUT_MODEL_FINAL}, best model to {best_path}")
    return agent, best_path


# ----------------------------
# Save / Load convenience
# ----------------------------
def save_model(agent: PPOAgent, path: str):
    agent.save(path)
    print(f"[+] Model saved to {path}")


def load_model(path: str, obs_dim: int) -> PPOAgent:
    agent = PPOAgent(obs_dim=obs_dim, lr=LR, hidden=tuple(HIDDEN_SIZES))
    agent.load(path)
    agent.net.to(DEVICE)
    agent.net.eval()
    print(f"[+] Loaded model from {path}")
    return agent


# ----------------------------
# Main entry
# ----------------------------
def main():
    if not os.path.exists(OBS_NPY):
        raise FileNotFoundError(f"Missing observations file: {OBS_NPY}")
    if not os.path.exists(LOG_CSV):
        raise FileNotFoundError(f"Missing log CSV: {LOG_CSV}")

    observations = np.load(OBS_NPY).astype(np.float32)
    df_logs = pd.read_csv(LOG_CSV, dtype=str).fillna("")
    if "label" not in df_logs.columns:
        df_logs["label"] = 0
    df_logs["label"] = pd.to_numeric(df_logs["label"], errors="coerce").fillna(0).astype(int)
    labels = df_logs["label"].values

    print("[*] Data loaded. N=", observations.shape[0], "obs_dim=", observations.shape[1])
    agent, best_path = train(observations, labels)

    # final evaluation using the final agent
    final_metrics = evaluate_policy(agent, observations, labels)
    print("\n=== Final Offline Metrics (final model) ===")
    print(f"Accuracy: {final_metrics['accuracy']:.4f}")
    print(f"Attack precision: {final_metrics['attack_prec']:.4f}")
    print(f"Attack recall   : {final_metrics['attack_rec']:.4f}")
    print(f"Attack F1       : {final_metrics['attack_f1']:.4f}")
    print("Confusion matrix:\n", final_metrics["confusion_matrix"])

    # save actions to CSV
    actions = final_metrics["actions"]
    df_out = df_logs.copy()
    df_out["policy_action"] = actions
    df_out.to_csv(OUT_CSV, index=False)
    print(f"[+] Saved actions to {OUT_CSV}")

    # also show best model path
    print(f"[+] Best model (by attack_f1) stored at: {best_path}")


if __name__ == "__main__":
    main()


[*] Data loaded. N= 7000 obs_dim= 32
[*] Starting training: epochs= 20 NUM_STEPS= 4096
[Epoch 1/20] avg_rew=0.0534 actions=ALLOW:1969 BLOCK:2127
  Offline: acc=0.7989 attack_f1=0.5272 attack_rec=0.3925 attack_prec=0.8027
  [+] New best model saved (attack_f1=0.5272)
[Epoch 2/20] avg_rew=0.1479 actions=ALLOW:2118 BLOCK:1978
  Offline: acc=0.8047 attack_f1=0.5278 attack_rec=0.3820 attack_prec=0.8536
  [+] New best model saved (attack_f1=0.5278)
[Epoch 3/20] avg_rew=0.2406 actions=ALLOW:2230 BLOCK:1866
  Offline: acc=0.8037 attack_f1=0.4982 attack_rec=0.3410 attack_prec=0.9241
[Epoch 4/20] avg_rew=0.3638 actions=ALLOW:2434 BLOCK:1662
  Offline: acc=0.8177 attack_f1=0.5663 attack_rec=0.4165 attack_prec=0.8843
  [+] New best model saved (attack_f1=0.5663)
[Epoch 5/20] avg_rew=0.4245 actions=ALLOW:2347 BLOCK:1749
  Offline: acc=0.8189 attack_f1=0.5582 attack_rec=0.4005 attack_prec=0.9207
[Epoch 6/20] avg_rew=0.4758 actions=ALLOW:2451 BLOCK:1645
  Offline: acc=0.8217 attack_f1=0.5691 attack_r