# Thesis Master Pipeline
## Self-Supervised Mamba-based NIDS with Token-based Early Detection (TED)

**Phases:**
1. Data Pipeline — Load, verify, split
2. SSL Pretraining — BiMamba + BERT encoders on benign-only data
3. Supervised Teachers — Fine-tune classifiers + XGBoost baseline
4. Knowledge Distillation — BiMamba Teacher → UniMamba Student
5. TED Early Exit + TTD (Time-To-Detect) — Blockwise exit at packets 8/16/32

**Rules:**
- `strict=True` for ALL weight loading — no exceptions
- ONE architecture per model — defined once, used everywhere
- Weights saved per-phase in `weights/phase{N}_*/`
- Each training cell checks if weights exist → loads or trains

In [7]:
# ══════════════════════════════════════════════════════════════════════
# CELL 1: Imports & Device Setup
# ══════════════════════════════════════════════════════════════════════
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pickle, os, time, copy, warnings, math, random
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, f1_score, accuracy_score, classification_report
from mamba_ssm import Mamba

warnings.filterwarnings('ignore')
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ── Reproducibility ──
# Seed every RNG the notebook may draw from. Python's `random` must be seeded
# too: augmentation helpers in this notebook call random.randint. Note that
# torch.manual_seed also seeds the generators of all CUDA devices.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)
np.random.seed(SEED)

# ── Paths ──
ROOT = Path('/home/T2510596/Downloads/totally fresh')
UNSW_DIR = ROOT / 'Organized_Final' / 'data' / 'unswnb15_full'
CIC_PATH = ROOT / 'thesis_final' / 'data' / 'cicids2017_flows.pkl'
CTU_PATH = ROOT / 'thesis_final' / 'data' / 'ctu13_flows.pkl'
WEIGHT_DIR = Path('weights')  # relative to notebook

# Ensure weight directories exist
for d in ['phase2_ssl', 'phase3_teachers', 'phase4_kd', 'phase5_ted']:
    (WEIGHT_DIR / d).mkdir(parents=True, exist_ok=True)

print(f'Device: {DEVICE}')
print(f'PyTorch: {torch.__version__}')
print(f'CUDA: {torch.version.cuda}')
if DEVICE.type == 'cuda':
    print(f'GPU: {torch.cuda.get_device_name()}')

Device: cuda
PyTorch: 2.5.1+cu124
CUDA: 12.4
GPU: NVIDIA GeForce RTX 4070 Ti SUPER


In [20]:
# ══════════════════════════════════════════════════════════════════════
# CELL 2: Load & Verify Datasets
# ══════════════════════════════════════════════════════════════════════

def load_pkl(path, name, fix_iat=False):
    """Load a pickled list of flow dicts, validate it, and print a summary.

    Each flow is a dict holding a ``'features'`` array (expected shape
    ``(32, 5)``) and a ``'label'``. Labels 0 and 1 are reported as benign and
    attack.

    Args:
        path: Pickle file to read. ``pickle.load`` can execute arbitrary
            code, so only pass trusted local files.
        name: Human-readable dataset name used in the log lines.
        fix_iat: If True, apply ``np.log1p`` to IAT (feature index 3) in place.
            Use this for CIC-IDS-2017, whose IAT column is raw microseconds and
            NOT log-normalized like UNSW/CTU. Without this fix the proj_iat
            linear layer (trained on UNSW IAT max≈14.7) sees values up to
            373,161, which are completely out-of-distribution embeddings.
            Integer feature arrays are first converted to float32 so the
            log values are not truncated on write-back.

    Returns:
        The loaded list of flow dicts (mutated in place when ``fix_iat``).

    Raises:
        ValueError: If the file holds no flows, or the first flow lacks a
            ``'features'`` key or has a shape other than ``(32, 5)``.
    """
    print(f'Loading {name} from {path}...')
    with open(path, 'rb') as f:
        data = pickle.load(f)  # trusted local file only (see docstring)
    print(f'  Loaded {len(data):,} flows')

    # Verify format on the first flow. Raise instead of assert so the checks
    # survive `python -O`.
    if not data:
        raise ValueError(f'{name}: no flows found in {path}')
    sample = data[0]
    if 'features' not in sample:
        raise ValueError(f'Missing "features" key in {name}')
    if sample['features'].shape != (32, 5):
        raise ValueError(f'Expected (32,5), got {sample["features"].shape}')

    # ── IAT normalization fix ──────────────────────────────────────
    if fix_iat:
        print('  Applying log1p to IAT (feature index 3)...')
        iat_before = np.array([d['features'][:, 3] for d in data[:100]]).max()
        for d in data:
            feats = d['features']
            # Assigning log1p output into an integer array would silently
            # truncate it, so promote non-float arrays first.
            if not np.issubdtype(feats.dtype, np.floating):
                feats = feats.astype(np.float32)
                d['features'] = feats
            feats[:, 3] = np.log1p(feats[:, 3])
        iat_after = np.array([d['features'][:, 3] for d in data[:100]]).max()
        print(f'  IAT max (sample): {iat_before:.1f} → {iat_after:.4f} (log1p applied ✓)')

    # Count labels
    labels = np.array([d['label'] for d in data])
    n_benign = (labels == 0).sum()
    n_attack = (labels == 1).sum()
    print(f'  Benign: {n_benign:,} ({100*n_benign/len(data):.1f}%)')
    print(f'  Attack: {n_attack:,} ({100*n_attack/len(data):.1f}%)')
    print(f'  Feature shape: {sample["features"].shape}')

    # ── Feature range report (first 5k flows only) ──────────────────
    feat_arr = np.array([d['features'] for d in data[:5000]])
    feat_names = ['Proto', 'LogLen', 'Flags', 'LogIAT', 'Dir']
    print('  Feature ranges (sample 5k):')
    for i, fn in enumerate(feat_names):
        col = feat_arr[:, :, i].flatten()
        print(f'    {fn:8s}: min={col.min():.3f}  max={col.max():.3f}  zeros={100*(col==0).mean():.1f}%')

    return data

# ── UNSW-NB15: benign-only pretraining split + mixed finetuning split ──
unsw_pretrain = load_pkl(path=UNSW_DIR / 'pretrain_50pct_benign.pkl',
                         name='UNSW Pretrain (benign-only)')
print()
unsw_finetune = load_pkl(path=UNSW_DIR / 'finetune_mixed.pkl',
                         name='UNSW Finetune (mixed)')
print()

# ── Cross-dataset evaluation sets ──
# CIC-IDS-2017 stores IAT as raw microseconds → log1p it on load.
# CTU-13 IAT is already log-normalized → leave it as is.
cicids = load_pkl(path=CIC_PATH, name='CIC-IDS-2017', fix_iat=True)
print()
ctu13 = load_pkl(path=CTU_PATH, name='CTU-13', fix_iat=False)


Loading UNSW Pretrain (benign-only) from /home/T2510596/Downloads/totally fresh/Organized_Final/data/unswnb15_full/pretrain_50pct_benign.pkl...
  Loaded 787,004 flows
  Benign: 787,004 (100.0%)
  Attack: 0 (0.0%)
  Feature shape: (32, 5)
  Feature ranges (sample 5k):
    Proto   : min=0.000  max=17.000  zeros=26.6%
    LogLen  : min=0.000  max=7.324  zeros=26.6%
    Flags   : min=0.000  max=24.000  zeros=32.0%
    LogIAT  : min=0.000  max=14.645  zeros=26.6%
    Dir     : min=0.000  max=1.000  zeros=63.1%

Loading UNSW Finetune (mixed) from /home/T2510596/Downloads/totally fresh/Organized_Final/data/unswnb15_full/finetune_mixed.pkl...
  Loaded 834,241 flows
  Benign: 787,005 (94.3%)
  Attack: 47,236 (5.7%)
  Feature shape: (32, 5)
  Feature ranges (sample 5k):
    Proto   : min=0.000  max=17.000  zeros=27.1%
    LogLen  : min=0.000  max=7.326  zeros=27.1%
    Flags   : min=0.000  max=24.000  zeros=32.1%
    LogIAT  : min=0.000  max=14.656  zeros=27.1%
    Dir     : min=0.000  max=1.000

In [21]:
# ══════════════════════════════════════════════════════════════════════
# CELL 3: Train/Val/Test Split + DataLoaders
# ══════════════════════════════════════════════════════════════════════

class FlowDataset(Dataset):
    """Map-style PyTorch dataset of (features, label) pairs for network flows.

    ``data`` is a sequence of dicts with ``'features'`` and ``'label'`` keys.
    All flows are stacked once, up front, into two tensors: ``features``
    (float32) and ``labels`` (int64). Indexing returns one row of each.
    """

    def __init__(self, data):
        stacked_feats = np.asarray([row['features'] for row in data], dtype=np.float32)
        stacked_labels = np.asarray([row['label'] for row in data], dtype=np.int64)
        self.features = torch.from_numpy(stacked_feats)
        self.labels = torch.from_numpy(stacked_labels)

    def __len__(self):
        return self.labels.shape[0]

    def __getitem__(self, idx):
        return self.features[idx], self.labels[idx]


# ── Stratified split of the finetune set: 70% train / 15% val / 15% test ──
labels_ft = np.array([flow['label'] for flow in unsw_finetune])
idx_train, idx_temp = train_test_split(
    range(len(unsw_finetune)),
    test_size=0.3,
    stratify=labels_ft,
    random_state=SEED,
)
# Halve the 30% hold-out into val and test, stratifying again on its labels.
labels_temp = labels_ft[idx_temp]
idx_val, idx_test = train_test_split(
    idx_temp,
    test_size=0.5,
    stratify=labels_temp,
    random_state=SEED,
)


def _take_flows(indices):
    """Return the finetune flows at *indices*, in order."""
    return [unsw_finetune[i] for i in indices]


train_data, val_data, test_data = map(_take_flows, (idx_train, idx_val, idx_test))

print(f'Train: {len(train_data):,}  Val: {len(val_data):,}  Test: {len(test_data):,}')

# ── Datasets & DataLoaders ──
BS = 512

train_ds, val_ds, test_ds = map(FlowDataset, (train_data, val_data, test_data))
pretrain_ds = FlowDataset(unsw_pretrain)  # benign-only for SSL
cic_ds = FlowDataset(cicids)
ctu_ds = FlowDataset(ctu13)

_shuffled = dict(batch_size=BS, shuffle=True, drop_last=True)   # training loaders
_ordered = dict(batch_size=BS, shuffle=False)                   # evaluation loaders
train_loader = DataLoader(train_ds, **_shuffled)
val_loader = DataLoader(val_ds, **_ordered)
test_loader = DataLoader(test_ds, **_ordered)
pretrain_loader = DataLoader(pretrain_ds, **_shuffled)
cic_loader = DataLoader(cic_ds, **_ordered)
ctu_loader = DataLoader(ctu_ds, **_ordered)

# ── Sanity check on one training batch ──
x_batch, y_batch = next(iter(train_loader))
print(f'\nBatch shapes — X: {x_batch.shape}, Y: {y_batch.shape}')
print(f'Label distribution in batch: 0={int((y_batch == 0).sum())}, 1={int((y_batch == 1).sum())}')
print(f'Feature ranges — min={x_batch.min():.4f}, max={x_batch.max():.4f}')

Train: 583,968  Val: 125,136  Test: 125,137

Batch shapes — X: torch.Size([512, 32, 5]), Y: torch.Size([512])
Label distribution in batch: 0=487, 1=25
Feature ranges — min=0.0000, max=24.0000


---
## Architecture Definitions

**ONE definition per model.** These are the ONLY architectures used throughout the entire pipeline.

| Model | d_model | Embedding dim | Layers | Params |
|-------|---------|---------------|--------|--------|
| PacketEmbedder | 256 | 32 | — | shared |
| BERT Encoder | 256 | 32 | 4L/4H/ff=1024 + [CLS] | ~4.6M |
| BiMamba Encoder | 256 | 32 | 4 fwd + 4 rev | ~3.7M |
| UniMamba Student | 256 | 32 | 4 fwd only | ~1.8M |

In [None]:

# ══════════════════════════════════════════════════════════════════════
# CELL 4: Architecture Definitions (THE canonical models)
# ══════════════════════════════════════════════════════════════════════

# ── Shared Packet Embedder ──
class PacketEmbedder(nn.Module):
    """Project each raw 5-feature packet into a d_model-dimensional token.

    Packet layout: [protocol, length, flags, IAT, direction].
    - protocol / flags / direction are categorical. They are cast to long,
      clamped to their table range (0-255, 0-63, 0-1) and embedded with
      widths de, de and de // 4.
    - length / IAT are continuous, each mapped by a 1 → de linear layer.
    The five pieces are concatenated (de * 4 + de // 4 = 136 wide for de=32),
    fused by a linear layer to d_model, and layer-normalized.
    """

    def __init__(self, d_model=256, de=32):
        super().__init__()
        dir_width = de // 4  # 8 when de=32
        # Creation order is kept as-is: it fixes the parameter-init RNG stream.
        self.emb_proto = nn.Embedding(256, de)
        self.emb_flags = nn.Embedding(64, de)
        self.emb_dir = nn.Embedding(2, dir_width)
        self.proj_len = nn.Linear(1, de)
        self.proj_iat = nn.Linear(1, de)
        self.fusion = nn.Linear(4 * de + dir_width, d_model)  # 136 → 256
        self.norm = nn.LayerNorm(d_model)

    @staticmethod
    def _category(x, column, upper):
        """Column *column* of x as a long index tensor clamped to [0, upper]."""
        return x[:, :, column].long().clamp(0, upper)

    def forward(self, x):
        pieces = (
            self.emb_proto(self._category(x, 0, 255)),
            self.proj_len(x[:, :, 1:2]),
            self.emb_flags(self._category(x, 2, 63)),
            self.proj_iat(x[:, :, 3:4]),
            self.emb_dir(self._category(x, 4, 1)),
        )
        fused = self.fusion(torch.cat(pieces, dim=-1))
        return self.norm(fused)


# ── Learned Positional Encoding (for BERT) ──
class LearnedPE(nn.Module):
    """Add a learned absolute positional embedding to a token sequence.

    Args:
        d_model: Embedding width. Must match the last dim of the input.
        max_len: Number of learnable positions. The default of 5000 matches
            the previously hard-coded table size, so existing checkpoints
            still load with strict=True.
    """

    def __init__(self, d_model=256, max_len=5000):
        super().__init__()
        self.max_len = max_len
        self.pe_emb = nn.Embedding(max_len, d_model)

    def forward(self, x):
        """x: (B, T, d_model) → x + pe[:T].

        Raises:
            ValueError: If T exceeds ``max_len``. The embedding lookup would
                otherwise fail with an opaque index error.
        """
        seq_len = x.size(1)
        if seq_len > self.max_len:
            raise ValueError(
                f'LearnedPE: sequence length {seq_len} exceeds max_len={self.max_len}'
            )
        positions = torch.arange(seq_len, device=x.device)
        return x + self.pe_emb(positions)


# ── BERT Encoder — Paper-compliant: 4 heads, [CLS] token ──
class BertEncoder(nn.Module):
    """Transformer flow encoder with a learnable [CLS] token and SSL heads.

    Defaults: 4 layers, 4 attention heads, d_model=256, feed-forward 1024,
    GELU, dropout 0.1, batch_first.
    - [CLS] (learnable, init N(0, 0.02²)) is prepended AFTER positional
      encoding, so it carries no positional embedding.
    - `forward` returns the layer-normalized sequence (B, 1 + T, d_model).
      Index 0 is [CLS].
    - SSL: `proj_head` projects the [CLS] output, and `recon_head` maps each
      packet position back to the 5 raw features. For classification the
      [CLS] output is used directly and proj_head is discarded.
    """

    def __init__(self, d_model=256, de=32, nhead=4, num_layers=4, ff=1024, proj_out=128):
        super().__init__()
        # Submodule creation order is kept: it fixes the parameter-init RNG stream.
        self.cls_token = nn.Parameter(0.02 * torch.randn(1, 1, d_model))
        self.tokenizer = PacketEmbedder(d_model, de)
        self.pos_encoder = LearnedPE(d_model)
        block = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=ff,
            dropout=0.1,
            activation='gelu',
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(block, num_layers=num_layers)
        self.norm = nn.LayerNorm(d_model)
        self.proj_head = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.ReLU(),
            nn.Linear(d_model, proj_out),
        )
        self.recon_head = nn.Linear(d_model, 5)

    def forward(self, x):
        packets = self.pos_encoder(self.tokenizer(x))            # (B, T, d)
        cls = self.cls_token.expand(x.size(0), -1, -1)           # (B, 1, d)
        encoded = self.transformer_encoder(torch.cat((cls, packets), dim=1))
        return self.norm(encoded)                                # (B, 1 + T, d), [CLS] at 0

    def get_ssl_outputs(self, x):
        hidden = self.forward(x)
        cls_out, packet_out = hidden[:, 0], hidden[:, 1:]
        return self.proj_head(cls_out), self.recon_head(packet_out), hidden


# ── BERT Classifier — CLS output → head (proj_head discarded per paper) ──
class BertClassifier(nn.Module):
    """Binary flow classifier on top of a default `BertEncoder`.

    The encoder's SSL projection head is unused. The [CLS] output (256-dim)
    feeds a 256 → 256 → 2 MLP with ReLU and dropout 0.1.
    """

    def __init__(self):
        super().__init__()
        self.encoder = BertEncoder()
        self.head = nn.Sequential(
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, 2),
        )

    def forward(self, x):
        cls_repr = self.encoder(x)[:, 0]  # (B, 256) — [CLS] position
        return self.head(cls_repr)


# ── BiMamba Encoder ──
class BiMambaEncoder(nn.Module):
    """Bidirectional Mamba encoder built from parallel forward and reverse stacks.

    At each depth, one Mamba block reads the sequence as-is and another reads
    it time-reversed (its output is flipped back). The two outputs are
    averaged, the residual is added, and one LayerNorm (shared across all
    depths) is applied. Returns (B, T, d_model).
    SSL heads: `proj_head` on the mean-pooled sequence, `recon_head` per position.
    """

    def __init__(self, d_model=256, de=32, n_layers=4):
        super().__init__()
        # Creation order is kept: it fixes the parameter-init RNG stream.
        self.tokenizer = PacketEmbedder(d_model, de)
        self.layers = self._mamba_stack(d_model, n_layers)
        self.layers_rev = self._mamba_stack(d_model, n_layers)
        self.norm = nn.LayerNorm(d_model)
        self.proj_head = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.ReLU(),
            nn.Linear(d_model, d_model),
        )
        self.recon_head = nn.Linear(d_model, 5)

    @staticmethod
    def _mamba_stack(d_model, n_layers):
        """A ModuleList of *n_layers* identically configured Mamba blocks."""
        return nn.ModuleList(
            Mamba(d_model, d_state=16, d_conv=4, expand=2) for _ in range(n_layers)
        )

    def forward(self, x):
        h = self.tokenizer(x)
        for fwd_block, rev_block in zip(self.layers, self.layers_rev):
            forward_out = fwd_block(h)
            backward_out = torch.flip(rev_block(torch.flip(h, dims=[1])), dims=[1])
            h = self.norm(h + 0.5 * (forward_out + backward_out))  # mean of directions + residual
        return h  # (B, T, d_model)

    def get_ssl_outputs(self, x):
        hidden = self.forward(x)
        pooled = hidden.mean(dim=1)  # global average pooling over packets
        return self.proj_head(pooled), self.recon_head(hidden), hidden


# ── BiMamba Classifier ──
class BiMambaClassifier(nn.Module):
    """Binary flow classifier on top of a default `BiMambaEncoder`.

    Uses the raw mean-pooled encoder output. The SSL proj_head is not used,
    because it collapses features and hurts classification. The head is
    256 → 128 → 2 with ReLU and dropout 0.1, matching the bimamba_teacher.pth
    trained weights.
    """

    def __init__(self):
        super().__init__()
        self.encoder = BiMambaEncoder()
        self.head = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, 2),
        )

    def forward(self, x):
        pooled = self.encoder(x).mean(dim=1)  # (B, 256)
        return self.head(pooled)


# ── UniMamba Student (forward-only, for KD) ──
class UniMambaStudent(nn.Module):
    """Forward-only Mamba student used for knowledge distillation.

    Uses *n_layers* forward Mamba blocks, i.e. half the blocks of BiMambaEncoder,
    which also has a reverse stack. Each block output is residual-added and
    passed through one shared LayerNorm. The sequence is then mean-pooled and
    classified by a 256 → 64 → 2 MLP.
    """

    def __init__(self, d_model=256, de=32, n_layers=4):
        super().__init__()
        # Creation order is kept: it fixes the parameter-init RNG stream.
        self.tokenizer = PacketEmbedder(d_model, de)
        self.layers = nn.ModuleList(
            Mamba(d_model, d_state=16, d_conv=4, expand=2) for _ in range(n_layers)
        )
        self.norm = nn.LayerNorm(d_model)
        self.head = nn.Sequential(nn.Linear(d_model, 64), nn.ReLU(), nn.Linear(64, 2))

    def forward(self, x):
        h = self.tokenizer(x)
        for block in self.layers:
            h = self.norm(h + block(h))  # residual + shared norm
        pooled = h.mean(dim=1)
        return self.head(pooled)


# ── Blockwise TED Student (exit at packets 8, 16, 32) ──
class BlockwiseTEDStudent(nn.Module):
    """UniMamba student with blockwise early exits at packets 8, 16 and 32.

    Design: ONE forward pass over all tokens, with no restarts. The tokenizer,
    residual add and LayerNorm act per packet. The design relies on the Mamba
    blocks being causal (feat[:, i] depends only on tokens 0..i). Under that
    assumption
        feat[:, :p, :].mean()  ≡  encode x[:, :p, :] alone, then mean
    so each exit just slices the shared causal features at no extra cost.

    Inference (`forward`): every flow leaves at the FIRST exit whose
    confidence head output is >= threshold. Flows no gate accepts use the
    last exit. Training (`forward_train`): logits and confidences from every
    exit, for a multi-exit loss.

    Earlier experiments reported ~99.3% of flows deciding at packet 8 and a
    ~1.42× Time-To-Detect speedup, since deployment need not wait for packets
    9-32. These figures are run-specific; re-measure before quoting.
    """
    EXIT_POINTS = [8, 16, 32]

    def __init__(self, d_model=256, de=32, n_layers=4):
        super().__init__()
        self.tokenizer = PacketEmbedder(d_model, de)
        self.layers = nn.ModuleList([Mamba(d_model, d_state=16, d_conv=4, expand=2) for _ in range(n_layers)])
        self.norm = nn.LayerNorm(d_model)

        # Exit classifiers at each checkpoint
        self.exit_classifiers = nn.ModuleDict({
            str(p): nn.Sequential(
                nn.Linear(d_model, 64), nn.ReLU(), nn.Dropout(0.1), nn.Linear(64, 2)
            ) for p in self.EXIT_POINTS
        })
        # Confidence heads (sigmoid → score in [0, 1] compared against `threshold`)
        self.confidence_heads = nn.ModuleDict({
            str(p): nn.Sequential(
                nn.Linear(d_model, 64), nn.ReLU(), nn.Linear(64, 1), nn.Sigmoid()
            ) for p in self.EXIT_POINTS
        })

    def _encode(self, x):
        """Encode x through all Mamba layers once: (B, T, 5) → (B, T, d_model).

        Shared by `forward` and `forward_train`, and usable directly for latency
        benchmarking.
        """
        feat = self.tokenizer(x)
        for layer in self.layers:
            feat = self.norm(layer(feat) + feat)
        return feat

    def _exit_outputs(self, feat, p):
        """(logits (B, 2), confidence (B,)) of exit *p* from the mean of the first *p* positions."""
        rep = feat[:, :p, :].mean(dim=1)  # causal slice — no extra compute
        key = str(p)
        return self.exit_classifiers[key](rep), self.confidence_heads[key](rep).squeeze(-1)

    def forward(self, x, threshold=0.9):
        """ONE forward pass, then per-flow early exit.

        Returns:
            (logits (B, 2), exit_packets (B,) long): the logits of the exit each
            flow took, and that exit's packet count (the last exit for flows no
            gate accepted).
        """
        feat = self._encode(x)
        batch = x.size(0)
        outputs = [(p, *self._exit_outputs(feat, p)) for p in self.EXIT_POINTS]

        # Default every flow to the final exit. Its logits are reused rather than
        # recomputed, so a second dropout sample is never drawn in train mode.
        # Cloning keeps the logits' dtype, which stays correct under autocast.
        results = outputs[-1][1].clone()
        exit_packets = torch.full((batch,), self.EXIT_POINTS[-1], dtype=torch.long, device=x.device)
        decided = torch.zeros(batch, dtype=torch.bool, device=x.device)

        for p, logits, conf in outputs:
            # Exit where confident enough and not already decided at an earlier gate
            exit_mask = (conf >= threshold) & ~decided
            results[exit_mask] = logits[exit_mask]
            exit_packets[exit_mask] = p
            decided |= exit_mask

        return results, exit_packets

    def forward_train(self, x):
        """Training: return {p: logits} and {p: confidence} from all exits for a multi-exit loss."""
        feat = self._encode(x)
        all_logits, all_confs = {}, {}
        for p in self.EXIT_POINTS:
            all_logits[p], all_confs[p] = self._exit_outputs(feat, p)
        return all_logits, all_confs


# ── Print parameter counts ──
def count_params(model):
    """Return how many scalar parameters *model* holds (trainable or frozen)."""
    total = 0
    for tensor in model.parameters():
        total += tensor.numel()
    return total

_PARAM_REPORT = (
    ('BERT Encoder (4h+CLS):', BertEncoder),
    ('BERT Classifier:', BertClassifier),
    ('BiMamba Encoder:', BiMambaEncoder),
    ('BiMamba Classifier:', BiMambaClassifier),
    ('UniMamba Student:', UniMambaStudent),
    ('Blockwise TED:', BlockwiseTEDStudent),
)

print('Architecture Parameter Counts:')
for _label, _build in _PARAM_REPORT:
    # Left-align labels to 23 chars so the right-aligned counts form one column.
    print(f'  {_label:<23}{count_params(_build()):>10,}')


Architecture Parameter Counts:
  BERT Encoder:        4,585,493
  BERT Classifier:     4,593,879
  BiMamba Encoder:     3,681,429
  BiMamba Classifier:  3,698,007
  UniMamba Student:    1,814,098
  Blockwise TED:      1,896,793


In [5]:
# ══════════════════════════════════════════════════════════════════════
# CELL 5: Utility Functions (eval, save/load, SSL loss)
# ══════════════════════════════════════════════════════════════════════

def save_weights(model, path):
    """Save ``model.state_dict()`` to *path* and verify the round trip.

    Creates the parent directory if needed. It then reloads the file as
    tensors only (``weights_only=True``, no arbitrary unpickling) and checks
    that the keys and tensor shapes match the live model.

    Raises:
        RuntimeError: If the reloaded file does not match the model.
    """
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    state = model.state_dict()
    torch.save(state, path)

    # Verify saved correctly (keys AND shapes, not just key names).
    check = torch.load(path, map_location='cpu', weights_only=True)
    if check.keys() != state.keys():
        raise RuntimeError('Key mismatch in saved weights!')
    bad_shapes = [k for k, v in state.items() if check[k].shape != v.shape]
    if bad_shapes:
        raise RuntimeError(f'Shape mismatch in saved weights: {bad_shapes}')
    print(f'  ✓ Saved & verified: {path} ({path.stat().st_size/1e6:.1f} MB)')


def load_weights(model, path):
    """Load a state_dict from *path* into *model* with ``strict=True``.

    The file is read with ``weights_only=True``, which restricts unpickling
    to tensors and plain containers. That is all a state_dict needs, and it
    avoids executing arbitrary pickled code.

    Returns:
        The same *model* object, for chaining.

    Raises:
        RuntimeError: On any missing/unexpected key or shape mismatch
            (from ``load_state_dict``).
    """
    sd = torch.load(path, map_location='cpu', weights_only=True)
    model.load_state_dict(sd, strict=True)  # NEVER strict=False
    print(f'  ✓ Loaded (strict=True): {path}')
    return model


def weights_exist(path):
    """Report (print) and return whether a regular file exists at *path*."""
    found = os.path.isfile(path)
    message = (
        f'  ✓ Found existing weights: {path}'
        if found
        else f'  ✗ No weights found at: {path} — will train from scratch'
    )
    print(message)
    return found


@torch.no_grad()
def evaluate_classifier(model, loader, device=DEVICE):
    """Evaluate a binary classifier on *loader*.

    Puts *model* in eval mode and leaves it there. Models that return a tuple
    (e.g. TED models returning ``(logits, exit_packets)``) are scored on the
    first element. The softmax probability of class 1 feeds ROC-AUC.

    Returns:
        ``(accuracy, f1, auc)``. F1 is for class 1 (zero_division=0).
        ``auc`` is ``float('nan')`` when the labels contain a single class,
        where ``roc_auc_score`` would raise ValueError.

    Raises:
        ValueError: If *loader* yields no batches.
    """
    model.eval()
    preds, labels, probs = [], [], []
    for x, y in loader:
        out = model(x.to(device))
        logits = out[0] if isinstance(out, tuple) else out
        probs.append(F.softmax(logits.float(), dim=1)[:, 1].cpu())
        preds.append(logits.argmax(dim=1).cpu())
        labels.append(y.cpu())  # labels never need to visit the device

    if not labels:
        raise ValueError('evaluate_classifier: loader yielded no batches')

    y_true = torch.cat(labels).numpy()
    y_pred = torch.cat(preds).numpy()
    y_prob = torch.cat(probs).numpy()

    acc = accuracy_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred, zero_division=0)
    # ROC-AUC is undefined with only one class present.
    auc = roc_auc_score(y_true, y_prob) if np.unique(y_true).size > 1 else float('nan')
    return acc, f1, auc


# ── SSL Losses ──
class NTXentLoss(nn.Module):
    """Normalized temperature-scaled cross-entropy (NT-Xent) contrastive loss.

    For B pairs of projections, all 2B vectors are L2-normalized and compared
    by cosine similarity divided by the temperature. Each row's target is its
    paired view; every other non-self row acts as a negative.

    Args:
        temperature: Softmax temperature τ. Defaults to 0.07; this notebook
            passes 0.5 explicitly.
    """

    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = temperature

    def forward(self, z_i, z_j):
        """z_i, z_j: (B, proj_dim) projections of the two views → scalar loss."""
        n = z_i.size(0)
        z = F.normalize(torch.cat([z_i, z_j], dim=0), dim=1)  # (2B, proj_dim)
        sim = torch.mm(z, z.T) / self.temperature              # (2B, 2B)

        # Exclude self-similarity using the dtype's most negative finite value.
        # A hard-coded -1e9 is not representable in float16.
        self_mask = torch.eye(2 * n, dtype=torch.bool, device=z.device)
        sim = sim.masked_fill(self_mask, torch.finfo(sim.dtype).min)

        # Positive pairs: row i ↔ row i+B (and row i+B ↔ row i).
        idx = torch.arange(n, device=z.device)
        targets = torch.cat([idx + n, idx])
        return F.cross_entropy(sim, targets)


print('✓ Utilities defined: ' + ', '.join(('save_weights', 'load_weights', 'weights_exist', 'evaluate_classifier', 'NTXentLoss')))

✓ Utilities defined: save_weights, load_weights, weights_exist, evaluate_classifier, NTXentLoss


---
## Phase 2: SSL Pretraining

**Goal:** Pretrain BiMamba and BERT encoders on **benign-only** UNSW data using **contrastive learning (NT-Xent)**.

### Paper-Specified Hyperparameters

| Parameter | Value | Note |
|-----------|-------|------|
| Batch Size | 128 | Paper specification |
| Learning Rate | 5e-5 | AdamW optimizer |
| Epochs | 1 | Single epoch |
| Temperature (τ) | 0.5 | NT-Xent contrastive loss |
| CutMix Ratio (λ) | 0.4 | 40% packet segment replacement |
| Dropout | 0.1 | In transformer encoder |
| Training Data | 60% benign (paper); this notebook loads `pretrain_50pct_benign.pkl` | Unlabeled benign flows |

### Augmentation Strategy

| Strategy | Detail |
|---|---|
| **AntiShortcut Masking** | Per-feature masking — `LogLen` 50% (aggressive), `Flags` 30%, `Proto` 20%, `Dir` 10%, `LogIAT` **never** masked |
| **IAT Jitter** | Gaussian noise σ=0.05 added to LogIAT (the only continuous feature) |
| **CutMix (40%)** | Swap a contiguous 40% packet segment from a random donor flow |

Two views are created per flow:
- **View 1:** AntiShortcut-masked original
- **View 2:** CutMix with random donor + AntiShortcut masking

Loss = **NT-Xent** (temperature=0.5) + **Reconstruction MSE** on masked positions.

**Weights saved to:** `weights/phase2_ssl/ssl_bimamba.pth`, `weights/phase2_ssl/ssl_bert.pth`

In [8]:
# ══════════════════════════════════════════════════════════════════════
# CELL 6: Phase 2 — SSL Pretraining (BiMamba + BERT)
#   Augmentation: AntiShortcut per-feature masking + CutMix 40%
#   Loss: NT-Xent contrastive + MSE reconstruction on masked positions
#   Paper parameters: BS=128, LR=5e-5, 1 epoch, τ=0.5
# ══════════════════════════════════════════════════════════════════════

_SSL_DIR = WEIGHT_DIR / 'phase2_ssl'
SSL_BIMAMBA_PATH = _SSL_DIR / 'ssl_bimamba.pth'
SSL_BERT_PATH = _SSL_DIR / 'ssl_bert.pth'

# Paper-specified SSL hyperparameters
SSL_BATCH_SIZE, SSL_EPOCHS, SSL_LR = 128, 1, 5e-5
SSL_TEMPERATURE = 0.5  # NT-Xent τ
CUTMIX_ALPHA = 0.4     # λ in paper: fraction of packets replaced by CutMix

# ── AntiShortcut Augmentation ──────────────────────────────────────
# Features per packet: [Proto, LogLen, Flags, LogIAT, Dir]
# Indices:               0      1       2       3      4

class AntiShortcutAugmentation:
    """Per-feature masking with domain-informed probabilities, plus IAT jitter.

    Packet features: [Proto, LogLen, Flags, LogIAT, Dir] at indices 0..4.
    Default masking probabilities (class attribute ``MASK_PROBS``):
    - LogLen  (idx 1): 50% mask — strongest shortcut, mask aggressively
    - Flags   (idx 2): 30% mask — moderate shortcut
    - Proto   (idx 0): 20% mask — weak shortcut
    - Dir     (idx 4): 10% mask — minimal information
    - LogIAT  (idx 3):  0% mask — NEVER mask timing, add jitter instead

    Masked entries are set to 0.0. Gaussian jitter is added to every position
    of feature 3, with no exception for all-zero packets.

    Args:
        mask_probs: Optional ``{feature_index: probability}`` override of
            ``MASK_PROBS``. None uses the class default.
        jitter_scale: Optional override of ``JITTER_SCALE``, the std of the
            IAT noise. None uses the class default.
    """
    MASK_PROBS = {0: 0.20, 1: 0.50, 2: 0.30, 3: 0.00, 4: 0.10}
    JITTER_SCALE = 0.05  # Gaussian noise on LogIAT

    def __init__(self, mask_probs=None, jitter_scale=None):
        # None → fall back to the class-level defaults at call time.
        self.mask_probs = mask_probs
        self.jitter_scale = jitter_scale

    def __call__(self, x):
        """x: (B, T, n_features) → (x_masked, mask_bool), both shaped like x."""
        probs = self.MASK_PROBS if self.mask_probs is None else self.mask_probs
        jitter = self.JITTER_SCALE if self.jitter_scale is None else self.jitter_scale

        # Local names avoid shadowing the module-level `F` (torch.nn.functional).
        n_batch, n_steps, n_feat = x.shape
        x_aug = x.clone()
        mask = torch.zeros(n_batch, n_steps, n_feat, dtype=torch.bool, device=x.device)

        for feat_idx, prob in probs.items():
            if prob > 0:
                feat_mask = torch.rand(n_batch, n_steps, device=x.device) < prob
                x_aug[:, :, feat_idx].masked_fill_(feat_mask, 0.0)  # writes through the view
                mask[:, :, feat_idx] = feat_mask

        # IAT jitter on feature 3. randn is always drawn (even with jitter 0) so
        # the RNG stream matches the default configuration.
        iat_noise = torch.randn(n_batch, n_steps, device=x.device) * jitter
        x_aug[:, :, 3] += iat_noise

        return x_aug, mask


# ── CutMix Augmentation ───────────────────────────────────────────

class CutMixAugmentation:
    """Swap a contiguous segment of packets in from a random donor flow.

    For a batch (B, T, ...), each sample i receives positions
    [s_i, s_i + cut) with ``cut = int(T * alpha)`` copied from a random
    *other* sample of the batch. s_i is uniform over the valid starts.

    Vectorized on the input's device and driven by torch's RNG, so it is
    reproducible under ``torch.manual_seed``. It no longer uses an unseeded
    Python ``random`` or a per-sample loop. Batches with fewer than 2
    samples, or with ``cut == 0``, have no donor or segment and are returned
    as an unchanged copy.
    """

    def __init__(self, alpha=0.4):
        self.alpha = alpha

    def __call__(self, x_batch):
        """x_batch: (B, T, n_features) → CutMixed copy with the same shape."""
        n_batch, n_steps = x_batch.shape[0], x_batch.shape[1]
        cut_len = int(n_steps * self.alpha)
        if n_batch < 2 or cut_len <= 0:
            return x_batch.clone()

        device = x_batch.device
        # Random donor for each sample, skipping self: draw from B-1 slots and
        # shift indices >= i up by one.
        donors = torch.randint(0, n_batch - 1, (n_batch,), device=device)
        donors = donors + (donors >= torch.arange(n_batch, device=device)).long()

        # Inclusive start range [0, max(0, T - cut_len)], as before.
        max_start = max(0, n_steps - cut_len)
        starts = torch.randint(0, max_start + 1, (n_batch, 1), device=device)
        positions = torch.arange(n_steps, device=device).unsqueeze(0)
        in_cut = (positions >= starts) & (positions < starts + cut_len)  # (B, T)
        in_cut = in_cut.reshape(n_batch, n_steps, *([1] * (x_batch.dim() - 2)))

        return torch.where(in_cut, x_batch[donors], x_batch)


# ── Instantiate augmentations ─────────────────────────────────────
anti_shortcut, cutmix = AntiShortcutAugmentation(), CutMixAugmentation(alpha=CUTMIX_ALPHA)


# ── SSL Training Loop ─────────────────────────────────────────────

def _masked_mse(pred, target, mask):
    """MSE between *pred* and *target* over the entries where *mask* is True (0 if none)."""
    if mask.any():
        return F.mse_loss(pred[mask], target[mask])
    return torch.tensor(0.0, device=pred.device)


def train_ssl_one_epoch(encoder, loader, optimizer, contrastive_loss_fn, device,
                        anti=None, mix=None):
    """Run one epoch of SSL: two augmented views, NT-Xent + masked reconstruction.

    View 1 is ``anti(x)``. View 2 is ``anti(mix(x))``. The loss is the
    contrastive loss on the two projections plus the mean of the two
    reconstruction MSEs. Each MSE is computed only on masked entries. View 1
    is reconstructed against the original x, and view 2 against the mixed
    batch. Gradients are clipped to norm 1.0.

    Args:
        encoder: Module exposing ``get_ssl_outputs(x) -> (proj, recon, hidden)``.
        loader: Iterable of ``(x, label)`` batches. Labels are ignored.
        optimizer: Optimizer over the encoder's parameters.
        contrastive_loss_fn: ``(proj1, proj2) -> scalar`` (e.g. NTXentLoss).
        device: Device the batches are moved to.
        anti: Masking augmentation ``x -> (x_masked, mask)``. Defaults to the
            global ``anti_shortcut``.
        mix: Mixing augmentation ``x -> x_mixed``. Defaults to the global
            ``cutmix``.

    Returns:
        ``(mean_total, mean_contrastive, mean_reconstruction)`` over batches.

    Raises:
        ValueError: If *loader* yields no batches.
    """
    anti = anti_shortcut if anti is None else anti
    mix = cutmix if mix is None else mix

    encoder.train()
    total_loss = total_con = total_rec = 0.0
    n_batches = 0

    for x, _ in loader:  # labels ignored for SSL
        x = x.to(device)

        # View 1: masked original. View 2: mixed with donors, then masked.
        x1, mask1 = anti(x)
        x_cutmixed = mix(x)
        x2, mask2 = anti(x_cutmixed)

        proj1, recon1, _ = encoder.get_ssl_outputs(x1)
        proj2, recon2, _ = encoder.get_ssl_outputs(x2)

        loss_contrastive = contrastive_loss_fn(proj1, proj2)
        # View 2's ground truth is the mixed batch, not the original.
        loss_recon = (_masked_mse(recon1, x, mask1)
                      + _masked_mse(recon2, x_cutmixed, mask2)) / 2
        loss = loss_contrastive + loss_recon

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(encoder.parameters(), 1.0)
        optimizer.step()

        total_loss += loss.item()
        total_con += loss_contrastive.item()
        total_rec += loss_recon.item()
        n_batches += 1

    if n_batches == 0:
        raise ValueError('train_ssl_one_epoch: loader yielded no batches')
    return total_loss / n_batches, total_con / n_batches, total_rec / n_batches


def run_ssl_pretraining(encoder, save_path, name, dataset=None):
    """Load SSL weights for *encoder* if they exist, otherwise pretrain and save.

    Args:
        encoder: Encoder exposing ``get_ssl_outputs`` (BiMambaEncoder / BertEncoder).
        save_path: Checkpoint path. If the file exists it is loaded with
            strict=True and no training happens.
        name: Label used in log lines.
        dataset: Unlabeled training set. Defaults to the global benign-only
            ``pretrain_ds``.

    Returns:
        The encoder, moved to ``DEVICE``.
    """
    if weights_exist(save_path):
        load_weights(encoder, save_path)
        encoder.to(DEVICE)
        return encoder

    epoch_word = 'epoch' if SSL_EPOCHS == 1 else 'epochs'
    print(f'\n  Training {name} SSL from scratch ({SSL_EPOCHS} {epoch_word})...')
    print(f'  Paper params: BS={SSL_BATCH_SIZE}, LR={SSL_LR}, τ={SSL_TEMPERATURE}, λ={CUTMIX_ALPHA}')
    print(f'  Augmentation: AntiShortcut masking + CutMix {int(CUTMIX_ALPHA*100)}%')

    encoder = encoder.to(DEVICE)
    optimizer = torch.optim.AdamW(encoder.parameters(), lr=SSL_LR, weight_decay=1e-4)
    contrastive_loss = NTXentLoss(temperature=SSL_TEMPERATURE)

    # SSL-specific DataLoader with the paper batch size.
    train_set = pretrain_ds if dataset is None else dataset
    ssl_loader = DataLoader(train_set, batch_size=SSL_BATCH_SIZE, shuffle=True, drop_last=True)

    for epoch in range(SSL_EPOCHS):
        t0 = time.time()
        avg_loss, avg_con, avg_rec = train_ssl_one_epoch(
            encoder, ssl_loader, optimizer, contrastive_loss, DEVICE
        )
        elapsed = time.time() - t0
        print(f'    Epoch {epoch+1}/{SSL_EPOCHS}: '
              f'loss={avg_loss:.4f} (con={avg_con:.4f} rec={avg_rec:.4f})  '
              f'({elapsed:.1f}s)')

    save_weights(encoder, save_path)
    return encoder


# ── Run SSL for both encoders ──
def _phase_banner(title):
    """Print *title* framed above and below by a 60-character '═' rule."""
    rule = '═' * 60
    for line in (rule, title, rule):
        print(line)


_phase_banner('Phase 2a: BiMamba SSL Pretraining')
ssl_bimamba = run_ssl_pretraining(BiMambaEncoder(), SSL_BIMAMBA_PATH, 'BiMamba')

print()
_phase_banner('Phase 2b: BERT SSL Pretraining')
ssl_bert = run_ssl_pretraining(BertEncoder(), SSL_BERT_PATH, 'BERT')

════════════════════════════════════════════════════════════
Phase 2a: BiMamba SSL Pretraining
════════════════════════════════════════════════════════════
  ✗ No weights found at: weights/phase2_ssl/ssl_bimamba.pth — will train from scratch

  Training BiMamba SSL from scratch (5 epochs)...
  Augmentation: AntiShortcut masking + CutMix 40%
    Epoch 1/5: loss=6.4943 (con=5.5927 rec=0.9015)  (306.8s)
    Epoch 2/5: loss=5.6272 (con=5.4347 rec=0.1925)  (306.9s)
    Epoch 3/5: loss=5.5552 (con=5.4149 rec=0.1402)  (307.7s)
    Epoch 4/5: loss=5.5140 (con=5.3946 rec=0.1194)  (307.9s)
    Epoch 5/5: loss=5.4862 (con=5.3803 rec=0.1058)  (307.9s)
  ✓ Saved & verified: weights/phase2_ssl/ssl_bimamba.pth (14.8 MB)

════════════════════════════════════════════════════════════
Phase 2b: BERT SSL Pretraining
════════════════════════════════════════════════════════════
  ✗ No weights found at: weights/phase2_ssl/ssl_bert.pth — will train from scratch

  Training BERT SSL from scratch (5 epochs)...


In [19]:

# ══════════════════════════════════════════════════════════════════════
# CELL 10: SSL Augmentation Ablation — 2 variants
#   A) cutmix  — BiMamba (load if exists) + BERT (train 5ep, mean pool fixed)
#   B) anti    — BiMamba only (BERT anti skipped)
#
#   LR=5e-5 (cosine), 5 epochs each, τ=0.5, BS=128
# ══════════════════════════════════════════════════════════════════════

_SSL_BS, _SSL_LR, _SSL_TEMP = 128, 5e-5, 0.5
_BI_EPOCHS = _BERT_EPOCHS = 5

# ── Augmentation helpers ──────────────────────────────────────────────
# Per-feature mask probability: Proto, LogLen, Flags, LogIAT (never), Dir
_ANTI_PROBS = {0: 0.20, 1: 0.50, 2: 0.30, 3: 0.00, 4: 0.10}

def _anti(x):
    """AntiShortcut: per-feature masking + IAT jitter."""
    B, T, _ = x.shape
    x_out = x.clone()
    for fi, p in _ANTI_PROBS.items():
        if p > 0:
            x_out[:, :, fi][torch.rand(B, T, device=x.device) < p] = 0.0
    x_out[:, :, 3] += torch.randn(B, T, device=x.device) * 0.05
    return x_out

def _cutmix(x, alpha=0.4):
    """CutMix view: splice a random contiguous segment from another sample.

    For each sample ``i`` in the batch, a donor ``j != i`` is drawn uniformly.
    A start ``s`` is drawn uniformly from ``[0, max(0, T - cut)]`` with
    ``cut = int(T * alpha)`` (40% of the sequence by default). Then
    ``out[i, s:s+cut] = x[j, s:s+cut]``; all other positions keep ``x[i]``.

    All randomness comes from torch's RNG, so results follow
    ``torch.manual_seed``. (The previous per-sample Python ``random.randint``
    was never seeded, which made runs irreproducible.) Batches with fewer than
    two samples have no possible donor and are returned unchanged (as a copy)
    instead of crashing in ``torch.randint``.

    Args:
        x: tensor whose first two dims are (batch, time); trailing dims are
            carried along unchanged.
        alpha: fraction of the time axis to replace.

    Returns:
        A new tensor with the same shape and dtype as ``x``; ``x`` is not modified.
    """
    n_batch, n_steps = x.shape[0], x.shape[1]
    if n_batch < 2:
        return x.clone()
    cut = int(n_steps * alpha)
    dev = x.device

    # Uniform donor != self: draw from B-1 values, then shift past own index.
    own = torch.arange(n_batch, device=dev)
    donors = torch.randint(0, n_batch - 1, (n_batch,), device=dev)
    donors = donors + (donors >= own).long()

    # Segment start per sample; randint's upper bound is exclusive, hence +1.
    starts = torch.randint(0, max(0, n_steps - cut) + 1, (n_batch, 1), device=dev)
    t = torch.arange(n_steps, device=dev).unsqueeze(0)
    in_segment = (t >= starts) & (t < starts + cut)               # (B, T)
    in_segment = in_segment.view(n_batch, n_steps, *([1] * (x.dim() - 2)))

    return torch.where(in_segment, x[donors], x)

def _make_views(x, variant):
    """Return two independently augmented views of ``x`` for contrastive SSL.

    Args:
        x: batch tensor passed to the augmentation function.
        variant: ``'cutmix'`` — two independent ``_cutmix`` draws (the paper's
            method); ``'anti'`` — two independent ``_anti`` masks.

    Returns:
        Tuple ``(view1, view2)``.

    Raises:
        ValueError: for any other ``variant``. (Previously the function fell
            through and returned ``None``, which surfaced later as an opaque
            "cannot unpack NoneType" error in the training loop.)
    """
    if variant == 'cutmix':
        return _cutmix(x), _cutmix(x)   # two independent donor cuts — paper method
    if variant == 'anti':
        return _anti(x), _anti(x)       # two independent masks
    raise ValueError(
        f"unknown SSL augmentation variant {variant!r}; expected 'cutmix' or 'anti'")

def _train_ssl(encoder, loader, optimizer, scheduler, loss_fn, variant, device, n_epochs):
    """Train ``encoder`` contrastively for ``n_epochs`` and log the mean loss per epoch.

    For every batch ``(x, _)`` from ``loader`` (labels are ignored), two views
    are built with ``_make_views(x, variant)``. Each view goes through
    ``encoder.get_ssl_outputs``, and the first output of each is passed to
    ``loss_fn``. Gradients are clipped to global norm 1.0. ``scheduler`` is
    stepped once per epoch.

    Args:
        encoder: module exposing ``get_ssl_outputs`` (returns a 3-tuple).
        loader: iterable of ``(x, label)`` batches.
        optimizer: optimizer over ``encoder.parameters()``.
        scheduler: LR scheduler stepped after each epoch.
        loss_fn: contrastive loss taking the two view outputs.
        variant: augmentation name forwarded to ``_make_views``.
        device: device the batches are moved to.
        n_epochs: number of epochs.

    Returns:
        list[float]: mean training loss for each epoch. The function previously
        returned nothing, so callers that ignore the result are unaffected.

    Raises:
        ValueError: if ``loader`` yields no batches in an epoch (for example, a
            dataset smaller than the batch size with ``drop_last=True``). This
            replaces the previous bare ZeroDivisionError from averaging.
    """
    encoder.train()
    history = []
    for ep in range(n_epochs):
        total, n = 0.0, 0
        for x, _ in loader:
            x = x.to(device)
            v1, v2 = _make_views(x, variant)
            p1, _, _ = encoder.get_ssl_outputs(v1)
            p2, _, _ = encoder.get_ssl_outputs(v2)
            loss = loss_fn(p1, p2)
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(encoder.parameters(), 1.0)
            optimizer.step()
            total += loss.item()
            n += 1
        if n == 0:
            raise ValueError(
                'SSL loader yielded no batches — is the dataset smaller than the '
                'batch size with drop_last=True?')
        scheduler.step()
        mean_loss = total / n
        history.append(mean_loss)
        print(f'    ep {ep+1}/{n_epochs}  loss={mean_loss:.4f}  lr={scheduler.get_last_lr()[0]:.2e}')
    return history

# ── Weight files per variant ──────────────────────────────────────────
# _VARIANTS[variant] = (BiMamba file, BERT file); a None file means that
# encoder is not built for this variant.
_VARIANTS = {
    'cutmix': ('ssl_bimamba_cutmix.pth', 'ssl_bert_cutmix.pth'),
    'anti':   ('ssl_bimamba_anti.pth',   None),   # BERT anti skipped
}

_ssl_loader = DataLoader(pretrain_ds, batch_size=_SSL_BS, shuffle=True, drop_last=True)
_criterion  = NTXentLoss(temperature=_SSL_TEMP)
_ssl_dir    = WEIGHT_DIR / 'phase2_ssl'

# (label, encoder class, epochs) in the order encoders are built and stored;
# zipped positionally with each variant's weight-file pair.
_ENCODER_SPECS = (
    ('BiMamba', BiMambaEncoder, _BI_EPOCHS),
    ('BERT',    BertEncoder,    _BERT_EPOCHS),
)

# ssl_encoders[variant] = (BiMamba encoder, BERT encoder or None)
ssl_encoders = {}

for variant, weight_files in _VARIANTS.items():
    banner = '═' * 60
    print(f'\n{banner}')
    print(f' Variant: {variant.upper()}  |  BiMamba={_BI_EPOCHS}ep  BERT={_BERT_EPOCHS}ep  LR={_SSL_LR}  τ={_SSL_TEMP}')
    print(banner)

    built = []
    for (enc_label, encoder_cls, n_ep), fname in zip(_ENCODER_SPECS, weight_files):
        if fname is None:
            print(f'  {enc_label}: skipped for this variant')
            built.append(None)
            continue

        path = _ssl_dir / fname
        enc = encoder_cls()
        have_weights = weights_exist(path)
        if have_weights:
            # Weights are loaded on CPU before the move to DEVICE.
            load_weights(enc, path)
        enc = enc.to(DEVICE)

        if not have_weights:
            optim = torch.optim.AdamW(enc.parameters(), lr=_SSL_LR, weight_decay=1e-4)
            sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=n_ep, eta_min=0)
            started = time.time()
            print(f'  Training {enc_label} ({n_ep} epochs)...')
            _train_ssl(enc, _ssl_loader, optim, sched, _criterion, variant, DEVICE, n_ep)
            print(f'  {enc_label} done [{time.time()-started:.1f}s]')
            save_weights(enc, path)

        built.append(enc)

    ssl_encoders[variant] = tuple(built)

print('\n✓ SSL pretraining complete — cutmix & anti variants ready for evaluation')



════════════════════════════════════════════════════════════
 Variant: CUTMIX  |  5 epochs  LR=0.001  τ=0.5
════════════════════════════════════════════════════════════
  ✗ No weights found at: weights/phase2_ssl/ssl_bimamba_cutmix.pth — will train from scratch
  Training BiMamba...
    ep 1/5  loss=4.2478  lr=9.05e-04
    ep 2/5  loss=4.0291  lr=6.55e-04
    ep 3/5  loss=3.9367  lr=3.45e-04
    ep 4/5  loss=3.8747  lr=9.55e-05
    ep 5/5  loss=3.8537  lr=0.00e+00
  BiMamba done [1593.7s]
  ✓ Saved & verified: weights/phase2_ssl/ssl_bimamba_cutmix.pth (14.8 MB)
  ✗ No weights found at: weights/phase2_ssl/ssl_bert_cutmix.pth — will train from scratch
  Training BERT...
    ep 1/5  loss=5.5217  lr=9.05e-04
    ep 2/5  loss=5.5413  lr=6.55e-04
    ep 3/5  loss=5.5413  lr=3.45e-04
    ep 4/5  loss=5.5413  lr=9.55e-05
    ep 5/5  loss=5.5413  lr=0.00e+00
  BERT done [631.7s]
  ✓ Saved & verified: weights/phase2_ssl/ssl_bert_cutmix.pth (18.4 MB)

════════════════════════════════════════════

---
## Phase 2 Evaluation: k-NN Anomaly Detection on Raw Encoder Representations

**Key Insight:** The contrastive projection head (`proj_head`) is trained for augmentation invariance (NT-Xent). It therefore discards whatever varies under augmentation, which **collapses** the representation space. This is a standard finding in the SimCLR / MoCo literature, where the representation taken *before* the projection head transfers better. Discarding `proj_head` and using the raw encoder output preserves discriminative structure for anomaly detection.

**Method:** k-NN (k=10) cosine similarity scoring on raw encoder representations.

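**Representation:** Raw encoder output, with no projection head. BiMamba uses global average pooling over its $T$ token states. BERT uses the hidden state of its leading `[CLS]` token:
$$\mathbf{r}_{\text{BiMamba}} = \frac{1}{T}\sum_{t=1}^{T} h_t, \qquad \mathbf{r}_{\text{BERT}} = h_{\text{[CLS]}}, \qquad \text{where } h = \text{Encoder}(x)$$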

**Anomaly Score:** Average cosine distance to k nearest benign training flows:
$$\text{score}(f) = 1 - \frac{1}{k}\sum_{i=1}^{k} \cos(\mathbf{r}_f, \mathbf{r}_{nn_i})$$

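**Why k-NN > max-sim:** Max-sim scores a flow by its single most similar benign neighbor, so one near-duplicate in the reference set can mask an anomaly, making the score noisy. k-NN (k=10) averages the similarities of the top-10 neighbors, which gives a more robust anomaly score.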

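**Evaluation:** UNSW test + CIC-IDS-2017 (cross-dataset) + CTU-13 → AUC. The benign reference set is a random sample of up to 20,000 flows from the benign pretraining data. AUC is reported as $\max(\text{AUC}, 1-\text{AUC})$ to tolerate the polarity inversion caused by cross-dataset domain shift. Note that this is optimistic, because the polarity is chosen using the test labels.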

In [None]:

# ══════════════════════════════════════════════════════════════════════
# CELL 12: Phase 2 Evaluation — k-NN Anomaly Detection (Raw Encoder Reps)
#
#   The contrastive projection head (proj_head) is optimised for
#   augmentation invariance and tends to collapse discriminative
#   structure (the usual SimCLR/MoCo observation), so it is bypassed:
#   representations are taken straight from the encoder output.
#
#   Scoring: mean cosine similarity to the k=10 nearest benign reps
#   Pooling: BiMamba → h.mean(dim=1)   (token average, NO proj_head)
#            BERT    → h[:, 0, :]      ([CLS] state,   NO proj_head)
# ══════════════════════════════════════════════════════════════════════

K_NEIGHBORS = 10  # number of nearest benign neighbours averaged per query


@torch.no_grad()
def extract_raw_reps(encoder, loader, device, use_cls=False):
    """Collect one pooled encoder representation per sample (no heads applied).

    Puts ``encoder`` in eval mode. Each batch ``(x, _)`` from ``loader`` is run
    through ``encoder`` directly; neither a projection head nor a classifier is
    applied. The hidden states are expected to be ``(B, T, d)`` — per the
    original notes, T=32 for BiMamba and 33 (with [CLS]) for BERT.

    Args:
        encoder: module mapping a batch to per-token hidden states.
        loader: iterable of ``(x, label)`` batches; labels are ignored.
        device: device the batches are moved to.
        use_cls: if True, keep position 0 (the [CLS] token); otherwise average
            over the token axis.

    Returns:
        CPU tensor of shape ``(N, d)`` concatenated over all batches.
    """
    def _pool(hidden):
        # [CLS] state or mean over the token axis.
        return hidden[:, 0, :] if use_cls else hidden.mean(dim=1)

    encoder.eval()
    return torch.cat([_pool(encoder(batch.to(device))).cpu() for batch, _ in loader])


def knn_auc(test_reps, test_labels, train_reps, k=K_NEIGHBORS,
            chunk_size=512, device=DEVICE, *, polarity_invariant=True):
    """Score test reps by k-NN cosine distance to benign reps and return ROC AUC.

    Anomaly score per test sample = ``1 - mean(top-k cosine similarities)``
    against ``train_reps``. Queries are processed in chunks of ``chunk_size`` to
    bound the size of the similarity matrix.

    When ``polarity_invariant`` is True (the default, and the original behaviour),
    ``max(auc, 1 - auc)`` is returned. This tolerates score-direction inversion
    under domain shift. For example, if CIC benign traffic is far from the UNSW
    benign reference set while CIC attacks sit closer, AUC=0.10 reports as 0.90.
    Be aware this picks the direction using the test labels, so the result is
    an optimistic upper bound. Pass ``polarity_invariant=False`` for the raw AUC.

    Args:
        test_reps: (N, d) tensor of query representations.
        test_labels: length-N binary labels (1 = attack).
        train_reps: (M, d) tensor of benign reference representations.
        k: neighbours averaged per query; clamped to M so that ``topk`` cannot
            fail on a small reference set.
        chunk_size: queries per similarity-matrix chunk.
        device: device used for the similarity computation.
        polarity_invariant: keyword-only; see above.

    Returns:
        float ROC AUC.

    Raises:
        ValueError: if ``train_reps`` is empty.
    """
    if len(train_reps) == 0:
        raise ValueError('train_reps is empty — no benign reference set for k-NN scoring')
    k = min(k, len(train_reps))

    db = F.normalize(train_reps.to(device), dim=1)
    scores = []
    for s in range(0, len(test_reps), chunk_size):
        q = F.normalize(test_reps[s:s + chunk_size].to(device), dim=1)
        topk_sims = torch.mm(q, db.T).topk(k, dim=1).values
        scores.append(topk_sims.mean(dim=1).cpu())
    anomaly_scores = 1.0 - torch.cat(scores).numpy()
    auc = roc_auc_score(test_labels, anomaly_scores)
    return max(auc, 1.0 - auc) if polarity_invariant else auc


# ── Load paper-param variants (1 epoch, paper hyperparams) ──
print('Loading paper-param SSL encoders...')
ssl_bimamba_paper = BiMambaEncoder()
load_weights(ssl_bimamba_paper, WEIGHT_DIR / 'phase2_ssl' / 'ssl_bimamba_paper.pth')
ssl_bimamba_paper = ssl_bimamba_paper.to(DEVICE)

ssl_bert_paper = BertEncoder()
load_weights(ssl_bert_paper, WEIGHT_DIR / 'phase2_ssl' / 'ssl_bert_paper.pth')
ssl_bert_paper = ssl_bert_paper.to(DEVICE)

# ── Build encoder dict: variant → (BiMamba encoder, BERT encoder or None) ──
all_encoders = {
    'paper':  (ssl_bimamba_paper, ssl_bert_paper),
    'cutmix': ssl_encoders['cutmix'],     # (BiMamba, BERT) from Cell 10
    'anti':   ssl_encoders['anti'],        # (BiMamba, None) from Cell 10
}

# ── Sample benign training reps (at most N_SAMPLE, for speed) ──
N_SAMPLE = 20000
_n_sample = min(N_SAMPLE, len(pretrain_ds))
_sample_idx = np.random.choice(len(pretrain_ds), _n_sample, replace=False)
_sample_ds  = torch.utils.data.Subset(pretrain_ds, _sample_idx)
_sample_loader = DataLoader(_sample_ds, batch_size=512, shuffle=False)

# Report the number actually drawn (smaller than N_SAMPLE for small datasets).
print(f'Extracting raw benign training reps ({_n_sample:,} sampled)...\n')
_train_reps = {}
for v, (bi_enc, bert_enc) in all_encoders.items():
    bi_reps   = extract_raw_reps(bi_enc, _sample_loader, DEVICE, use_cls=False)
    bert_reps = extract_raw_reps(bert_enc, _sample_loader, DEVICE, use_cls=True) if bert_enc is not None else None
    _train_reps[v] = (bi_reps, bert_reps)

# ── Evaluate on UNSW + CIC + CTU ──
_eval_sets = [
    ('UNSW-NB15',    test_loader,  test_ds.labels.numpy()),
    ('CIC-IDS-2017', cic_loader,   cic_ds.labels.numpy()),
    ('CTU-13',       ctu_loader,   ctu_ds.labels.numpy()),
]

print(f'Phase 2 SSL Evaluation — k-NN(k={K_NEIGHBORS}) on Raw Encoder Reps\n')
print(f"{'Variant':<8}  {'Dataset':<15}  {'BiMamba':>10}  {'BERT':>10}")
print('─' * 50)

# ssl_ablation_results[(variant, dataset)] = (BiMamba AUC, BERT AUC or None)
ssl_ablation_results = {}
for v, (bi_enc, bert_enc) in all_encoders.items():
    bi_train, bert_train = _train_reps[v]
    for ds_name, loader, labels in _eval_sets:
        bi_reps = extract_raw_reps(bi_enc, loader, DEVICE, use_cls=False)
        auc_bi  = knn_auc(bi_reps, labels, bi_train)
        if bert_enc is not None:
            bert_reps = extract_raw_reps(bert_enc, loader, DEVICE, use_cls=True)
            auc_bert  = knn_auc(bert_reps, labels, bert_train)
            bert_str  = f'{auc_bert:>10.4f}'
        else:
            auc_bert = None
            bert_str = '      skip'
        ssl_ablation_results[(v, ds_name)] = (auc_bi, auc_bert)
        print(f'{v:<8}  {ds_name:<15}  {auc_bi:>10.4f}  {bert_str}')
    print()

# ── Pick best BiMamba variant by CIC AUC ──
# NOTE(review): CIC-IDS-2017 is also reported as an evaluation set, so selecting
# the variant by its AUC uses test data for model selection; a held-out
# validation split would keep the CIC numbers unbiased.
best_v = max(
    [k for k in ssl_ablation_results if k[1] == 'CIC-IDS-2017'],
    key=lambda k: ssl_ablation_results[k][0]
)[0]
print('─' * 50)
print(f'Best variant for cross-dataset (CIC) AUC: {best_v.upper()}')

# ── Set encoder paths for Phase 3 ──
ssl_bimamba_new = all_encoders[best_v][0]
ssl_bert_new    = ssl_bert_paper   # BERT paper is the only fully trained BERT

SSL_BIMAMBA_PATH_NEW = WEIGHT_DIR / 'phase2_ssl' / f'ssl_bimamba_{best_v}.pth'
SSL_BERT_PATH_NEW    = WEIGHT_DIR / 'phase2_ssl' / 'ssl_bert_paper.pth'

print(f'Encoder paths for Phase 3 → {SSL_BIMAMBA_PATH_NEW.name}, {SSL_BERT_PATH_NEW.name}')


Extracting benign training reps for all variants...

Phase 2 SSL Ablation — Similarity AUC

Variant   Dataset           BiMamba(5ep)   BERT(5ep)
────────────────────────────────────────────────────
cutmix    UNSW-NB15               0.8819      0.5000
cutmix    CIC-IDS-2017            0.4394      0.5000

anti      UNSW-NB15               0.6323      0.4996
anti      CIC-IDS-2017            0.4171      0.4970

both      UNSW-NB15               0.9133      0.3838
both      CIC-IDS-2017            0.1683      0.5081

────────────────────────────────────────────────────
Best variant for BiMamba CIC AUC: CUTMIX
Encoder paths set → ssl_bimamba_cutmix.pth, ssl_bert_cutmix.pth


---
## Phase 3: Supervised Teacher Fine-Tuning

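Load SSL-pretrained encoder weights → freeze nothing → fine-tune end-to-end with class-weighted cross-entropy (attack weight = benign/attack count ratio) on labeled UNSW data. The epoch with the best validation F1 is kept.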

| Model | SSL Init | Epochs | LR | Save Path |
|-------|----------|--------|-----|-----------|
| BertClassifier | `ssl_bert_paper.pth` → `encoder` | 10 | 1e-4 | `phase3_teachers/bert_teacher.pth` |
| BiMambaClassifier | `ssl_bimamba_{best_v}.pth` (best Phase 2 variant by CIC AUC) → `encoder` | 10 | 1e-4 | `phase3_teachers/bimamba_teacher.pth` |

**eval on:** UNSW test, CIC-IDS-2017, CTU-13 → Acc / F1 / AUC

In [None]:

# ══════════════════════════════════════════════════════════════════════
# CELL 13: Phase 3 — Supervised Teacher Fine-Tuning
# ══════════════════════════════════════════════════════════════════════

BERT_TEACHER_PATH    = WEIGHT_DIR / 'phase3_teachers' / 'bert_teacher.pth'
BIMAMBA_TEACHER_PATH = WEIGHT_DIR / 'phase3_teachers' / 'bimamba_teacher.pth'

FT_EPOCHS = 10    # same for BERT and BiMamba (fair comparison)
FT_LR     = 1e-4

# ── Class weights: attack weight = benign/attack count ratio in train_ds ──
# (original note: 16.7:1 imbalance, 94.3% benign, 5.7% attack)
_ft_labels = train_ds.labels.numpy()
_n_benign  = int((_ft_labels == 0).sum())
_n_attack  = int((_ft_labels == 1).sum())
if _n_attack == 0:
    # Without attacks the ratio is infinite and the weighted CE becomes NaN.
    raise ValueError('train_ds contains no attack samples (label 1); '
                     'cannot compute class weights')
CLASS_WEIGHTS = torch.tensor([1.0, _n_benign / _n_attack],
                               dtype=torch.float32, device=DEVICE)
print(f'  Class weights: benign=1.0, attack={CLASS_WEIGHTS[1]:.1f}  '
      f'({_n_benign:,} vs {_n_attack:,})')


def finetune_teacher(classifier, ssl_encoder_weights_path, save_path, name):
    """Initialise a classifier from SSL encoder weights and fine-tune it end-to-end.

    If ``weights_exist(save_path)``, those weights are loaded into
    ``classifier``, and it is returned on DEVICE without training.

    Otherwise the SSL state dict at ``ssl_encoder_weights_path`` is loaded into
    ``classifier.encoder`` with ``strict=True``. The whole model is then trained
    for ``FT_EPOCHS`` epochs using AdamW (``FT_LR``, weight decay 1e-4), cosine LR
    annealing, class-weighted cross-entropy (``CLASS_WEIGHTS``) and gradient
    clipping at norm 1.0. After each epoch the model is evaluated on
    ``val_loader``. The state with the best validation F1 is restored and saved
    to ``save_path``.

    Args:
        classifier: model exposing an ``.encoder`` sub-module.
        ssl_encoder_weights_path: path to the SSL encoder state dict.
        save_path: where the fine-tuned classifier is loaded from / saved to.
        name: label used in log messages.

    Returns:
        The loaded or fine-tuned classifier, on DEVICE.
    """
    if weights_exist(save_path):
        load_weights(classifier, save_path)
        return classifier.to(DEVICE)

    # Load SSL encoder weights into the .encoder sub-module (strict=True).
    # weights_only=False unpickles arbitrary objects — acceptable only because
    # these checkpoints are produced locally by Phase 2.
    sd = torch.load(ssl_encoder_weights_path, map_location='cpu', weights_only=False)
    classifier.encoder.load_state_dict(sd, strict=True)
    print(f'  ✓ Loaded SSL encoder from {ssl_encoder_weights_path}')

    classifier = classifier.to(DEVICE)
    optimizer  = torch.optim.AdamW(classifier.parameters(), lr=FT_LR, weight_decay=1e-4)
    scheduler  = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=FT_EPOCHS)
    # Class-weighted CE: the attack class is up-weighted by CLASS_WEIGHTS[1].
    criterion  = nn.CrossEntropyLoss(weight=CLASS_WEIGHTS)

    best_val_f1 = float('-inf')
    best_sd     = None
    best_epoch  = 0

    for epoch in range(FT_EPOCHS):
        classifier.train()
        total_loss, n = 0.0, 0
        for x, y in train_loader:
            x, y = x.to(DEVICE), y.to(DEVICE)
            loss = criterion(classifier(x), y)
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(classifier.parameters(), 1.0)
            optimizer.step()
            total_loss += loss.item()
            n += 1
        scheduler.step()

        acc, f1, auc = evaluate_classifier(classifier, val_loader)
        print(f'  Epoch {epoch+1}/{FT_EPOCHS}  loss={total_loss/max(n, 1):.4f}  '
              f'val acc={acc:.4f} f1={f1:.4f} auc={auc:.4f}')
        # Always keep the first epoch. Previously, with best_val_f1 starting
        # at 0.0, a run whose val F1 never rose above 0 left best_sd = None,
        # and load_state_dict(None) crashed after training.
        if best_sd is None or f1 > best_val_f1:
            best_val_f1 = f1
            best_sd = copy.deepcopy(classifier.state_dict())
            best_epoch = epoch + 1

    if best_sd is not None:   # None only when FT_EPOCHS == 0
        classifier.load_state_dict(best_sd)
        print(f'  {name}: restored epoch {best_epoch} (best val f1={best_val_f1:.4f})')
    save_weights(classifier, save_path)
    return classifier


# ── Fine-tune both teachers: BERT first, then BiMamba ──
_rule = '═' * 60

print(_rule, f'Phase 3a: BERT Teacher  ({FT_EPOCHS} epochs, LR={FT_LR})', _rule, sep='\n')
bert_teacher = finetune_teacher(BertClassifier(), SSL_BERT_PATH_NEW,
                                BERT_TEACHER_PATH, 'BERT')
print()

print(_rule, f'Phase 3b: BiMamba Teacher  ({FT_EPOCHS} epochs, LR={FT_LR})', _rule, sep='\n')
bimamba_teacher = finetune_teacher(BiMambaClassifier(), SSL_BIMAMBA_PATH_NEW,
                                   BIMAMBA_TEACHER_PATH, 'BiMamba')
print()

# ── Report Acc / F1 / AUC for every teacher on every evaluation split ──
print('Phase 3 Teacher Results\n')
print(f"{'Model':<18}  {'Dataset':<18}  {'Acc':>7}  {'F1':>7}  {'AUC':>7}")
print('─' * 65)

_teacher_models = (('BERT Teacher', bert_teacher), ('BiMamba Teacher', bimamba_teacher))
_teacher_splits = (('UNSW Test', test_loader),
                   ('CIC-IDS-2017', cic_loader),
                   ('CTU-13', ctu_loader))
for model_name, model in _teacher_models:
    for ds_name, loader in _teacher_splits:
        acc, f1, auc = evaluate_classifier(model.eval(), loader)
        row = f'{model_name:<18}  {ds_name:<18}  {acc:>7.4f}  {f1:>7.4f}  {auc:>7.4f}'
        print(row)
    print()

print('✓ Phase 3 Complete')
