"""
# MambaShield: Resilient Temporally-Aware Intrusion Detection Based on Selective State Space Models
## Authors: Roger Nick Anaedevha, Alexander Gennadevich Trofimov, Yuri Vladimirovich Borodachev
## Paper: IEEE Transactions on Artificial Intelligence (TAI)

### Complete implementation with all methodologies, algorithms, and evaluation metrics.
### Optimized for GPU memory efficiency and production deployment.

Citation:
@article{anaedevha2025mambashield,
  title={MambaShield: Resilient Temporally-Aware Intrusion Detection Based on Selective State Space Models},
  author={Anaedevha, Roger Nick and Trofimov, Alexander Gennadevich and Borodachev, Yuri Vladimirovich},
  journal={IEEE Transactions on Artificial Intelligence},
  year={2025}
}
"""


# ======================================================================================================
# IMPORTING LIBRARIES
# ======================================================================================================

# In [1]:

import os
import sys
import gc
import time
import math
import json
import random
import warnings
import traceback
from datetime import datetime
from typing import Dict, List, Tuple, Optional, Union, Any, Callable
from collections import defaultdict, deque
from dataclasses import dataclass, field

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import DataLoader, TensorDataset, Dataset, ConcatDataset
from torch.nn.utils import clip_grad_norm_

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns

# Machine Learning
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import StandardScaler, LabelEncoder, RobustScaler
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    confusion_matrix, roc_auc_score, roc_curve, auc,
    average_precision_score, matthews_corrcoef, cohen_kappa_score,
    classification_report, balanced_accuracy_score,
    brier_score_loss
)
from sklearn.calibration import calibration_curve  # or CalibrationDisplay

from sklearn.utils import shuffle

# Suppress warnings
warnings.filterwarnings('ignore')
plt.style.use('seaborn-v0_8-darkgrid')



# ============================================================================
# CONFIGURATION AND HYPERPARAMETERS
# ============================================================================

# In [None]:


@dataclass
class MambaShieldConfig:
    """Configuration for MambaShield model as per paper specifications.

    Field declaration order matters: the dataclass-generated ``__init__``
    accepts these fields positionally in this order, so new fields should be
    appended rather than inserted.

    Readers visible in this file: ``MambaShieldModel`` (architecture fields),
    ``ProgressiveARD`` (``distill_temperature``, ``distill_beta``),
    ``HierarchicalRL`` (``rl_c51_atoms``, ``rl_v_min``, ``rl_v_max``),
    ``PACBayesRegularizer`` (``pac_bayes_*``) and ``PoisoningAttackSimulator``
    (``poison_*``). The remaining fields are not read by the code shown here;
    presumably a training/evaluation driver consumes them -- TODO confirm.
    """
    # Model Architecture (Section III.B)
    input_dim: int = 84  # Unified feature dimension (last axis of the model input)
    hidden_dim: int = 256  # D in paper; width after the input projection
    d_state: int = 16  # N in paper; SSM state size per channel
    d_conv: int = 4  # Kernel size of the depthwise causal conv in each Mamba block
    n_layers: int = 3  # Number of stacked Mamba blocks
    n_scales: int = 4  # Multi-scale temporal aggregation; scale s subsamples time by 2**s
    dropout: float = 0.1  # Used in the input projection, Mamba blocks and classifier
    
    # Training Configuration (Section V.A.3)
    # NOTE(review): not read by the code in this view; presumably used by the training loop.
    batch_size: int = 32
    learning_rate: float = 1e-4
    weight_decay: float = 1e-5
    epochs: int = 100
    early_stopping_patience: int = 10
    gradient_clip: float = 1.0  # presumably the max norm for clip_grad_norm_ -- confirm
    
    # PARD Configuration (Section III.C)
    n_teachers: int = 4  # Intended teacher count; ProgressiveARD itself does not read it
    distill_temperature: float = 3.0  # Softmax temperature T for distillation; KD loss is scaled by T**2
    distill_beta: float = 0.1  # Progressive schedule rate: alpha(epoch) = 1 - exp(-beta * epoch)
    distill_epochs: int = 20  # Not read in this view
    
    # Hierarchical RL (Section III.D)
    rl_learning_rate_high: float = 5e-4  # presumably the strategic-policy LR; not read in this view
    rl_learning_rate_low: float = 1e-3  # presumably the tactical-policy LR; not read in this view
    rl_gamma: float = 0.99  # assumed discount factor; not read in this view
    rl_c51_atoms: int = 51  # Number of atoms in the C51 value-distribution support
    rl_v_min: float = -10.0  # Lowest value of the C51 support
    rl_v_max: float = 10.0  # Highest value of the C51 support
    
    # PAC-Bayes (Section III.E)
    pac_bayes_prior_variance: float = 1.0  # Variance of the zero-mean Gaussian prior over weights
    pac_bayes_delta: float = 0.05  # Confidence parameter delta of the PAC-Bayes bound
    
    # Poisoning Attack Parameters (Section V.A.2)
    poison_epsilon: float = 0.1  # Perturbation magnitude (sign-gradient step size / clamp radius)
    poison_tau: float = 0.05  # Temporal coherence: max L2 change of the perturbation between time steps
    poison_rate: float = 0.2  # Fraction of samples that are poisoned or label-flipped
    poison_iterations: int = 10  # Optimisation steps per sample in clean-label poisoning
    
    # Memory Management
    use_mixed_precision: bool = True  # presumably toggles autocast/GradScaler; not read in this view
    memory_threshold: float = 0.8  # presumably forwarded to MemoryManager(threshold=...) -- confirm
    checkpoint_interval: int = 5  # Not read in this view (assumed: epochs between checkpoints)
    
    # Evaluation
    n_metrics: int = 23  # Size of the evaluation metric suite; not read in this view
    cross_validation_folds: int = 5
    test_size: float = 0.2  # presumably the held-out test fraction -- confirm
    val_size: float = 0.2  # presumably the validation fraction -- confirm
    random_seed: int = 42




# ============================================================================
# MEMORY MANAGEMENT
# ============================================================================

# In [None]:


class MemoryManager:
    """Advanced GPU memory management for efficient training.

    Reports CUDA memory usage, frees cached blocks, and probes for a batch size
    that fits in device memory. Every CUDA query is guarded by ``_uses_cuda()``
    so the manager is a harmless no-op for CPU devices, including on machines
    where CUDA is installed (CUDA APIs raise when handed a CPU device).
    """

    def __init__(self, device: torch.device, threshold: float = 0.8):
        self.device = device
        # Fraction of total device memory above which check_memory() clears caches.
        self.threshold = threshold
        self.memory_stats = defaultdict(list)

    def _uses_cuda(self) -> bool:
        """Return True only when CUDA is available and ``self.device`` is a CUDA device."""
        return torch.cuda.is_available() and torch.device(self.device).type == 'cuda'

    def _total_memory_bytes(self) -> int:
        """Return the total memory of ``self.device`` in bytes (CUDA only)."""
        return torch.cuda.get_device_properties(self.device).total_memory

    def get_memory_info(self) -> Dict[str, float]:
        """Return current CUDA memory statistics in GB (1e9 bytes).

        Returns:
            Dict with ``allocated_gb``, ``reserved_gb``, ``max_allocated_gb`` and
            ``free_gb`` (total minus allocated), or an empty dict when
            ``self.device`` is not a usable CUDA device.
        """
        if not self._uses_cuda():
            return {}
        allocated = torch.cuda.memory_allocated(self.device)
        return {
            'allocated_gb': allocated / 1e9,
            'reserved_gb': torch.cuda.memory_reserved(self.device) / 1e9,
            'max_allocated_gb': torch.cuda.max_memory_allocated(self.device) / 1e9,
            'free_gb': (self._total_memory_bytes() - allocated) / 1e9,
        }

    def clear_cache(self):
        """Run Python GC and, on CUDA devices, release cached blocks and synchronize."""
        gc.collect()
        if self._uses_cuda():
            torch.cuda.empty_cache()
            torch.cuda.synchronize(self.device)

    def check_memory(self) -> bool:
        """Return True if allocated memory is at or below ``self.threshold``.

        When usage exceeds the threshold the caches are cleared and False is
        returned. Always True on non-CUDA devices.
        """
        if not self._uses_cuda():
            return True

        usage_ratio = torch.cuda.memory_allocated(self.device) / self._total_memory_bytes()
        if usage_ratio > self.threshold:
            self.clear_cache()
            return False
        return True

    def optimize_batch_size(self, model: nn.Module, input_shape: Tuple,
                            initial_batch: int = 64,
                            target_ratio: float = 0.7) -> int:
        """Find optimal batch size for available memory.

        Tries ``initial_batch``, ``initial_batch // 2``, ... (down to 2) and
        returns the first size whose no-grad forward pass succeeds with peak
        allocated memory below ``target_ratio`` of the device total.

        Args:
            model: Module to probe; its train/eval mode is restored on exit.
            input_shape: Per-sample input shape (without the batch dimension).
            initial_batch: First batch size to try.
            target_ratio: Maximum allowed peak-allocated / total-memory ratio.

        Returns:
            The chosen batch size. On non-CUDA devices there is no memory budget
            to check, so the first size whose forward pass succeeds is returned.
            Falls back to 1 when no probed size fits.

        Raises:
            RuntimeError: any error other than CUDA out-of-memory is re-raised
                instead of being silently treated as "too large".

        Note:
            On CUDA this resets the device's peak-memory statistics so that the
            measured peak belongs to the probe only.
        """
        was_training = model.training
        use_cuda = self._uses_cuda()
        batch_size = initial_batch
        model.eval()
        try:
            while batch_size > 1:
                try:
                    self.clear_cache()
                    if use_cuda:
                        torch.cuda.reset_peak_memory_stats(self.device)
                    dummy_input = torch.randn(batch_size, *input_shape, device=self.device)
                    with torch.no_grad():
                        model(dummy_input)

                    if not use_cuda:
                        return batch_size

                    # Compare bytes with bytes (the old code divided GB by bytes,
                    # so the check could never fail), and use the peak rather
                    # than the post-forward allocation, which excludes activations.
                    peak_ratio = (torch.cuda.max_memory_allocated(self.device)
                                  / self._total_memory_bytes())
                    if peak_ratio < target_ratio:
                        return batch_size
                except RuntimeError as err:
                    if 'out of memory' not in str(err):
                        raise
                    self.clear_cache()

                batch_size //= 2
        finally:
            model.train(was_training)

        return max(batch_size, 1)




# ============================================================================
# SELECTIVE STATE SPACE MODEL (MAMBA) - Algorithm 1 from Paper
# ============================================================================

# In [None]:


class SelectiveSSM(nn.Module):
    """
    Selective State Space Model implementation based on Mamba
    Reference: Section III.B of the paper, Algorithm 1

    Maps an input of shape (B, L, d_model) to an output of the same shape.
    The inner width is ``expand_factor * d_model``. Delta, B and C are
    projected from the input at every time step; that is the "selective" part.
    """
    
    def __init__(self, d_model: int, d_state: int = 16, d_conv: int = 4, 
                 expand_factor: int = 2, dt_rank: Optional[int] = None):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand_factor
        self.d_inner = int(self.expand * d_model)
        
        # Rank of the low-rank Delta projection (paper Eq. 20); default ceil(d_model / 16).
        if dt_rank is None:
            self.dt_rank = math.ceil(d_model / 16)
        else:
            self.dt_rank = dt_rank
            
        # Joint projection producing the SSM branch and the gate branch.
        self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)
        
        # Depthwise conv; padding d_conv-1 plus trimming to L in forward() makes it causal.
        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            kernel_size=d_conv,
            padding=d_conv - 1,
            groups=self.d_inner
        )
        
        # Input-dependent SSM parameters: [Delta (dt_rank) | B (d_state) | C (d_state)].
        self.x_proj = nn.Linear(self.d_inner, self.dt_rank + d_state * 2, bias=False)
        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)
        
        # Initialize dt bias as per paper
        dt_init_std = self.dt_rank ** -0.5
        nn.init.uniform_(self.dt_proj.bias, -dt_init_std, dt_init_std)
        
        # State matrix A (Equation 23), stored as log-magnitude; A = -exp(A_log) stays negative.
        A = self._init_state_matrix()
        self.A_log = nn.Parameter(torch.log(A))
        
        # D parameter for the per-channel skip connection
        self.D = nn.Parameter(torch.ones(self.d_inner))
        
        # Output projection back to d_model
        self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)
        
    def _init_state_matrix(self) -> torch.Tensor:
        """Return a (d_inner, d_state) float tensor whose every row is 1..d_state.

        The scan uses A = -exp(log(this)), i.e. a diagonal state matrix with
        entries -1, -2, ..., -d_state for every channel.
        """
        A = torch.arange(1, self.d_state + 1).reshape(1, -1).repeat(self.d_inner, 1)
        return A.float()
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass implementing Algorithm 1 from the paper
        Args:
            x: Input tensor of shape (B, L, D)
        Returns:
            Output tensor of shape (B, L, D)
        """
        _, seq_len, _ = x.shape
        
        # Dual branch projection (Eq. 19)
        xz = self.in_proj(x)  # (B, L, 2*D_inner)
        x, z = xz.chunk(2, dim=-1)  # Each (B, L, D_inner)
        
        # Causal depthwise convolution: drop the right-hand padding outputs.
        x = x.transpose(1, 2)  # (B, D_inner, L)
        x = self.conv1d(x)[:, :, :seq_len]
        x = x.transpose(1, 2)  # (B, L, D_inner)
        
        x = F.silu(x)
        
        # SSM computation
        y = self.selective_scan(x)
        
        # Gating mechanism (Hadamard product with the SiLU-activated gate branch)
        y = y * F.silu(z)
        
        return self.out_proj(y)
    
    def selective_scan(self, u: torch.Tensor) -> torch.Tensor:
        """
        Implements the selective scan algorithm (Algorithm 1)

        Args:
            u: (B, L, D_inner) activations after the causal convolution.
        Returns:
            (B, L, D_inner) SSM output including the D skip term.
        """
        # Compute Delta, B, C from input (Equations 20-22)
        deltaBC = self.x_proj(u)  # (B, L, dt_rank + 2*d_state)
        
        delta, B, C = torch.split(
            deltaBC, 
            [self.dt_rank, self.d_state, self.d_state], 
            dim=-1
        )
        
        # Transform delta (Equation 20); softplus keeps the step size positive.
        delta = F.softplus(self.dt_proj(delta))  # (B, L, D_inner)
        
        A = -torch.exp(self.A_log.float())  # (D_inner, d_state)
        
        # Discretization (Zero-Order Hold for A, Euler for B) - Equations 23-24
        deltaA = torch.exp(delta.unsqueeze(-1) * A.unsqueeze(0).unsqueeze(0))  # (B, L, D_inner, d_state)
        deltaB = delta.unsqueeze(-1) * B.unsqueeze(2)  # (B, L, D_inner, d_state)
        
        y = self.parallel_scan(u, deltaA, deltaB, C)
        
        # Add skip connection with D parameter
        return y + u * self.D
    
    def parallel_scan(self, u: torch.Tensor, deltaA: torch.Tensor, 
                     deltaB: torch.Tensor, C: torch.Tensor) -> torch.Tensor:
        """
        Run the linear recurrence h_t = deltaA_t * h_{t-1} + deltaB_t * u_t,
        y_t = <h_t, C_t> over the sequence.
        Reference: Algorithm 1, lines 8-10

        Despite the name this is a sequential O(L) loop, not a parallel prefix
        scan. Per-step outputs are collected and stacked once instead of being
        written in place into a preallocated tensor: repeated in-place slice
        writes make autograd copy the full (B, L, D_inner) gradient once per
        step, i.e. O(L^2) backward cost.

        Args:
            u: (B, L, D_inner) inputs.
            deltaA, deltaB: (B, L, D_inner, d_state) discretized parameters.
            C: (B, L, d_state) output projections.
        Returns:
            (B, L, D_inner) tensor in ``u``'s dtype.
        """
        batch, seq_len, d_inner = u.shape
        d_state = deltaA.shape[-1]
        if seq_len == 0:
            return torch.zeros_like(u)
        
        h = torch.zeros(batch, d_inner, d_state, device=u.device, dtype=u.dtype)
        outputs = []
        for t in range(seq_len):
            h = deltaA[:, t] * h + deltaB[:, t] * u[:, t].unsqueeze(-1)
            outputs.append((h * C[:, t].unsqueeze(1)).sum(dim=-1))
            
        return torch.stack(outputs, dim=1).to(u.dtype)




# ============================================================================
# MAMBA BLOCK WITH LAYER NORM AND RESIDUAL
# ============================================================================


# In [None]:

class MambaBlock(nn.Module):
    """
    Pre-norm residual wrapper around a SelectiveSSM
    Reference: Section III.B of the paper

    Computes ``Dropout(SSM(LayerNorm(x))) + x``; output shape equals input shape.
    """

    def __init__(self, d_model: int, d_state: int = 16, d_conv: int = 4,
                 expand_factor: int = 2, dropout: float = 0.1):
        super().__init__()
        # Submodules are registered in the order norm -> ssm -> dropout so
        # parameter ordering and state_dict keys stay the same.
        self.norm = nn.LayerNorm(d_model)
        self.ssm = SelectiveSSM(d_model, d_state, d_conv, expand_factor)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add the normalised, dropped-out SSM branch back onto the input."""
        branch = self.dropout(self.ssm(self.norm(x)))
        return branch + x




# ============================================================================
# MULTI-SCALE TEMPORAL AGGREGATION
# ============================================================================

# In [None]:


class MultiScaleTemporalAggregation(nn.Module):
    """
    Multi-scale temporal aggregation module
    Reference: Section III.B.4, Equations 25-27

    For each scale s, a SelectiveSSM is run on the sequence subsampled by
    2**s, the result is linearly upsampled back to the input length, and the
    scales are blended with per-time-step softmax weights. Input and output
    are both (B, L, D).
    """

    def __init__(self, d_model: int, n_scales: int = 4):
        super().__init__()
        self.n_scales = n_scales
        self.d_model = d_model

        # One SSM per scale; the conv kernel doubles with the scale index (2, 4, 8, ...).
        self.scale_ssms = nn.ModuleList(
            [SelectiveSSM(d_model, d_state=16, d_conv=2 ** (scale + 1))
             for scale in range(n_scales)]
        )

        # Scores the concatenated per-scale features, one logit per scale (Eq. 26).
        self.scale_attention = nn.Linear(d_model * n_scales, n_scales)

    @staticmethod
    def _apply_at_stride(x: torch.Tensor, ssm: nn.Module, stride: int,
                         seq_len: int) -> torch.Tensor:
        """Run ``ssm`` on every ``stride``-th step of ``x`` and restore length ``seq_len``."""
        if stride == 1:
            return ssm(x)
        coarse = ssm(x[:, ::stride, :])
        restored = F.interpolate(
            coarse.transpose(1, 2),
            size=seq_len,
            mode='linear',
            align_corners=False,
        )
        return restored.transpose(1, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the attention-weighted blend of all scales, shaped like ``x``."""
        _, seq_len, _ = x.shape

        per_scale = [
            self._apply_at_stride(x, ssm, 2 ** scale, seq_len)
            for scale, ssm in enumerate(self.scale_ssms)
        ]

        # (B, L, n_scales) mixing weights from the concatenated features (Eq. 26).
        weights = F.softmax(self.scale_attention(torch.cat(per_scale, dim=-1)), dim=-1)

        # Accumulate in place into a tensor shaped and typed like x (Eq. 27).
        blended = torch.zeros_like(x)
        for scale in range(self.n_scales):
            blended += weights[:, :, scale:scale + 1] * per_scale[scale]
        return blended




# ============================================================================
# PROGRESSIVE ADVERSARIAL ROBUSTNESS DISTILLATION (PARD) - Algorithm 2
# ============================================================================

# In [None]:


class ProgressiveARD(nn.Module):
    """
    Progressive Adversarial Robustness Distillation
    Reference: Section III.C, Algorithm 2

    Wraps a student model and an initially empty ``teachers`` ModuleList. In
    training mode with at least one teacher, the returned ``total_loss`` is
    ``alpha * CE + (1 - alpha) * KD``, where ``alpha(epoch) = 1 - exp(-beta * epoch)``
    and KD is averaged uniformly over the teachers.
    """
    
    def __init__(self, student: nn.Module, config: MambaShieldConfig):
        super().__init__()
        self.student = student
        self.config = config
        self.teachers = nn.ModuleList()
        self.temperature = config.distill_temperature
        self.beta = config.distill_beta
        
    def create_teacher(self, attack_type: str) -> nn.Module:
        """Create a fresh MambaShieldModel to act as a teacher.

        ``attack_type`` is currently unused. The teacher is NOT added to
        ``self.teachers`` here; it is presumably trained separately on a
        specific attack type and appended by the caller.
        """
        teacher = MambaShieldModel(self.config)
        return teacher
    
    def distillation_loss(self, student_logits: torch.Tensor, 
                         teacher_logits: torch.Tensor) -> torch.Tensor:
        """
        Compute knowledge distillation loss
        Reference: Equation 17 in related work

        KL(teacher_T || student_T) on temperature-softened distributions,
        scaled by T**2 so gradient magnitudes stay comparable across T.
        """
        student_soft = F.log_softmax(student_logits / self.temperature, dim=-1)
        teacher_soft = F.softmax(teacher_logits / self.temperature, dim=-1)
        
        loss = F.kl_div(student_soft, teacher_soft, reduction='batchmean')
        return loss * (self.temperature ** 2)
    
    def progressive_schedule(self, epoch: int) -> float:
        """
        Progressive weighting schedule alpha(t) = 1 - exp(-beta * t)
        Reference: Algorithm 2, line 16
        """
        return 1 - math.exp(-self.beta * epoch)
    
    def forward(self, x: torch.Tensor, labels: torch.Tensor, 
                epoch: int = 0) -> Dict[str, torch.Tensor]:
        """
        Forward pass with progressive distillation.

        Returns:
            Always ``logits`` and ``distill_loss``. In training mode also
            ``task_loss``, ``total_loss`` and ``alpha``; when teachers are
            present, additionally ``teacher_confidences`` (one mean max-softmax
            value per teacher, detached).
        """
        student_output = self.student(x)
        
        if not self.training or len(self.teachers) == 0:
            outputs = {
                'logits': student_output,
                'distill_loss': student_output.new_zeros(()),
            }
            if self.training and labels is not None:
                # No teachers yet: fall back to the plain task loss so training
                # loops that read 'total_loss' keep working.
                task_loss = F.cross_entropy(student_output, labels)
                outputs.update(task_loss=task_loss, total_loss=task_loss, alpha=1.0)
            return outputs
        
        alpha = self.progressive_schedule(epoch)
        task_loss = F.cross_entropy(student_output, labels)
        
        # Teachers are frozen: only their forward passes run without autograd.
        with torch.no_grad():
            teacher_logits = []
            for teacher in self.teachers:
                teacher.eval()
                teacher_logits.append(teacher(x))
        
        # The KD loss must be built OUTSIDE no_grad so it back-propagates into
        # the student; computing it inside the no_grad block (as before) left
        # the distillation term without any gradient.
        distill_loss = sum(
            self.distillation_loss(student_output, logits) for logits in teacher_logits
        ) / len(teacher_logits)
        
        # Teacher confidence (Algorithm 2, line 21): reported only. Teachers
        # are averaged uniformly above. Kept as a tensor to avoid a device sync.
        teacher_confidences = torch.stack([
            F.softmax(logits, dim=-1).max(dim=-1).values.mean()
            for logits in teacher_logits
        ])
        
        # Combined loss (Algorithm 2, line 25)
        total_loss = alpha * task_loss + (1 - alpha) * distill_loss
        
        return {
            'logits': student_output,
            'task_loss': task_loss,
            'distill_loss': distill_loss,
            'total_loss': total_loss,
            'alpha': alpha,
            'teacher_confidences': teacher_confidences,
        }




# ============================================================================
# HIERARCHICAL REINFORCEMENT LEARNING MODULE
# ============================================================================

# In [None]:


class HierarchicalRL(nn.Module):
    """
    Hierarchical Reinforcement Learning for adaptive decision making
    Reference: Section III.D, Equations 28-30

    From a state vector (..., state_dim) it produces softmax distributions over
    4 strategic and 10 tactical actions, plus a categorical (C51-style) value
    distribution over a fixed support and its expectation.
    """
    
    def __init__(self, state_dim: int, config: MambaShieldConfig):
        super().__init__()
        self.config = config
        
        # High-level strategic policy (Equation 28)
        self.strategic_policy = nn.Sequential(
            nn.Linear(state_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 4)  # {Monitor, Investigate, Block, Adapt}
        )
        
        # Low-level tactical policy (Equation 29), conditioned on the
        # strategic action distribution appended to the state.
        self.tactical_policy = nn.Sequential(
            nn.Linear(state_dim + 4, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 10)  # Specific tactical actions
        )
        
        # C51 distributional value estimation (Equation 30)
        self.n_atoms = config.rl_c51_atoms
        self.v_min = config.rl_v_min
        self.v_max = config.rl_v_max
        self.delta_z = (self.v_max - self.v_min) / (self.n_atoms - 1)
        # Non-persistent buffer: moves with .to()/.cuda()/.half() like the
        # rest of the module, yet stays out of state_dict so existing
        # checkpoints load unchanged.
        self.register_buffer(
            'support',
            torch.linspace(self.v_min, self.v_max, self.n_atoms),
            persistent=False,
        )
        
        self.value_network = nn.Sequential(
            nn.Linear(state_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, self.n_atoms)
        )
        
    def forward(self, state: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Hierarchical decision making.

        Returns:
            ``strategic`` (..., 4) and ``tactical`` (..., 10) action
            probabilities, ``value_dist`` (..., n_atoms) and ``expected_value``
            (...), the mean of the value distribution over ``support``.
        """
        # Strategic decision
        strategic_logits = self.strategic_policy(state)
        strategic_action = F.softmax(strategic_logits, dim=-1)
        
        # Tactical decision conditioned on the (soft) strategic distribution
        tactical_input = torch.cat([state, strategic_action], dim=-1)
        tactical_logits = self.tactical_policy(tactical_input)
        tactical_action = F.softmax(tactical_logits, dim=-1)
        
        # Value distribution
        value_logits = self.value_network(state)
        value_dist = F.softmax(value_logits, dim=-1)
        
        expected_value = (value_dist * self.support).sum(dim=-1)
        
        return {
            'strategic': strategic_action,
            'tactical': tactical_action,
            'value_dist': value_dist,
            'expected_value': expected_value
        }




# ============================================================================
# PAC-BAYES REGULARIZATION
# ============================================================================

# In [None]:


class PACBayesRegularizer(nn.Module):
    """
    PAC-Bayes regularization for certified robustness
    Reference: Section III.E, Theorem 3, Equation 34

    Treats the trainable weights theta as the mean of an isotropic Gaussian
    posterior and compares it with the zero-mean isotropic Gaussian prior
    N(0, prior_variance * I).
    """
    
    def __init__(self, model: nn.Module, config: MambaShieldConfig):
        super().__init__()
        self.model = model
        self.prior_variance = config.pac_bayes_prior_variance
        self.delta = config.pac_bayes_delta
        
    def compute_kl_divergence(self, posterior_variance: Optional[float] = None) -> torch.Tensor:
        """
        KL( N(theta, s^2 I) || N(0, sigma^2 I) ) over all trainable parameters.
        Reference: Section III.E

        Closed form: 0.5 * sum_i [ s^2/sigma^2 + theta_i^2/sigma^2 - 1 - ln(s^2/sigma^2) ].
        With the default ``posterior_variance=None`` (s^2 = sigma^2) this is
        sum(theta^2) / (2 sigma^2). That is always >= 0 and has the same
        gradient as the previous expression, which omitted the posterior
        variance and could go negative (turning the bound into NaN).

        Args:
            posterior_variance: s^2; defaults to the prior variance.
        Returns:
            Scalar tensor (differentiable w.r.t. the model parameters).
        Raises:
            ValueError: if either variance is not positive.
        """
        prior_var = float(self.prior_variance)
        post_var = prior_var if posterior_variance is None else float(posterior_variance)
        if prior_var <= 0.0 or post_var <= 0.0:
            raise ValueError(
                f"variances must be positive (prior={prior_var}, posterior={post_var})"
            )
        ratio = post_var / prior_var
        # Per-coordinate constant; exactly 0 when the variances are equal.
        const_per_coord = ratio - 1.0 - math.log(ratio)
        
        sq_norm = None
        n_coords = 0
        for param in self.model.parameters():
            if param.requires_grad:
                # Accumulate in float32 so half-precision weights cannot overflow.
                term = param.float().pow(2).sum()
                sq_norm = term if sq_norm is None else sq_norm + term
                n_coords += param.numel()
        if sq_norm is None:
            return torch.zeros(())
        return 0.5 * (sq_norm / prior_var + const_per_coord * n_coords)
    
    def pac_bayes_bound(self, empirical_risk: Union[torch.Tensor, float], 
                       n_samples: int) -> Dict[str, float]:
        """
        Compute the McAllester-style PAC-Bayes generalization bound
        risk + sqrt((KL + ln(2 sqrt(n) / delta)) / (2 n)).
        Reference: Theorem 3, Equation 34

        Args:
            empirical_risk: Scalar tensor or float.
            n_samples: Number of training samples n (must be positive).
        Returns:
            Floats for ``empirical_risk``, ``kl_divergence``, ``complexity``
            and ``bound``.
        Raises:
            ValueError: if ``n_samples`` is not positive.
        """
        if n_samples <= 0:
            raise ValueError(f"n_samples must be positive, got {n_samples}")
        
        # Reporting only: no autograd graph needed.
        with torch.no_grad():
            kl = float(self.compute_kl_divergence())
        risk = float(empirical_risk)
        
        slack = (kl + math.log(2 * math.sqrt(n_samples) / self.delta)) / (2 * n_samples)
        complexity = math.sqrt(max(slack, 0.0))
        
        return {
            'empirical_risk': risk,
            'kl_divergence': kl,
            'complexity': complexity,
            'bound': risk + complexity
        }




# ============================================================================
# COMPLETE MAMBASHIELD MODEL
# ============================================================================

# In [None]:


class MambaShieldModel(nn.Module):
    """
    Complete MambaShield architecture
    Reference: Figure 1 and Section III

    Pipeline: input projection -> ``n_layers`` MambaBlocks -> multi-scale
    temporal aggregation -> classifier applied to the last time step. The
    hierarchical RL head runs only when ``return_features=True``.
    """
    
    def __init__(self, config: MambaShieldConfig):
        super().__init__()
        self.config = config
        # Output classes; configs without a ``num_classes`` field keep the
        # original 7 unified attack categories.
        self.num_classes = getattr(config, 'num_classes', 7)
        
        # Input preprocessing (config.input_dim unified features -> hidden_dim)
        self.input_proj = nn.Sequential(
            nn.Linear(config.input_dim, config.hidden_dim),
            nn.LayerNorm(config.hidden_dim),
            nn.ReLU(),
            nn.Dropout(config.dropout)
        )
        
        # Mamba blocks (Section III.B)
        self.mamba_blocks = nn.ModuleList([
            MambaBlock(
                config.hidden_dim, 
                config.d_state, 
                config.d_conv,
                dropout=config.dropout
            )
            for _ in range(config.n_layers)
        ])
        
        # Multi-scale temporal aggregation
        self.multi_scale = MultiScaleTemporalAggregation(
            config.hidden_dim, 
            config.n_scales
        )
        
        # Classification head
        self.classifier = nn.Sequential(
            nn.LayerNorm(config.hidden_dim),
            nn.Linear(config.hidden_dim, config.hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.hidden_dim // 2, self.num_classes)
        )
        
        # Hierarchical RL component
        self.rl_module = HierarchicalRL(config.hidden_dim, config)
        
    def forward(self, x: torch.Tensor, 
                return_features: bool = False) -> Union[torch.Tensor, Dict]:
        """
        Forward pass through MambaShield.

        Args:
            x: Input of shape (B, L, input_dim), or (B, input_dim), which is
                treated as a length-1 sequence.
            return_features: Whether to return intermediate features and the
                RL decision alongside the logits.
        Returns:
            Logits (B, num_classes), or a dict with ``logits``, ``features``
            (B, hidden_dim) and ``rl_decision`` when ``return_features`` is True.
        Raises:
            ValueError: if ``x`` is neither 2-D nor 3-D.
        """
        if x.dim() == 2:
            # One flow vector per sample: treat it as a sequence of length 1.
            x = x.unsqueeze(1)
        elif x.dim() != 3:
            raise ValueError(
                f"expected input of shape (B, L, F) or (B, F), got {tuple(x.shape)}"
            )
        
        x = self.input_proj(x)
        
        for mamba_block in self.mamba_blocks:
            x = mamba_block(x)
        
        x = self.multi_scale(x)
        
        # Use last timestep for classification
        features = x[:, -1, :]
        logits = self.classifier(features)
        
        if return_features:
            return {
                'logits': logits,
                'features': features,
                'rl_decision': self.rl_module(features)
            }
        
        return logits




# ============================================================================
# POISONING ATTACK IMPLEMENTATIONS
# ============================================================================

# In [None]:


class PoisoningAttackSimulator:
    """
    Implementation of poisoning attacks from Section V.A.2

    Gradients w.r.t. inputs are taken with ``torch.autograd.grad`` so the
    attacked model's parameter ``.grad`` buffers are never modified (the
    previous ``loss.backward()`` calls leaked gradients into any training
    step that followed).
    """
    
    def __init__(self, config: MambaShieldConfig):
        self.config = config
        self.epsilon = config.poison_epsilon
        self.tau = config.poison_tau
        self.rate = config.poison_rate
        
    def gradient_based_poisoning(self, model: nn.Module, x: torch.Tensor, 
                                y: torch.Tensor) -> torch.Tensor:
        """
        Gradient-based poisoning (GBP): one signed-gradient step of size
        ``epsilon`` that increases the cross-entropy.
        Reference: Equation 37

        For sequence input (B, L, D) a temporal coherence constraint is applied
        per sample: the L2 change of the perturbation between consecutive time
        steps is limited to ``tau``.

        Returns:
            The poisoned input, detached.
        """
        x_adv = x.clone().detach().requires_grad_(True)
        loss = F.cross_entropy(model(x_adv), y)
        (grad,) = torch.autograd.grad(loss, x_adv)
        
        perturbation = self.epsilon * grad.sign()
        
        if perturbation.dim() == 3:  # (B, L, D)
            for t in range(1, perturbation.shape[1]):
                delta = perturbation[:, t] - perturbation[:, t - 1]
                # Per-sample norm (previously one norm over the whole batch).
                norms = delta.norm(dim=-1, keepdim=True)
                scale = torch.clamp(self.tau / norms.clamp_min(1e-12), max=1.0)
                perturbation[:, t] = perturbation[:, t - 1] + delta * scale
        
        return (x + perturbation).detach()
    
    def label_flipping(self, y: torch.Tensor, num_classes: int) -> torch.Tensor:
        """
        Label flipping attack: a ``rate`` fraction of labels is replaced by a
        uniformly chosen *different* class in ``[0, num_classes)``.
        Reference: Section V.A.2.2

        Raises:
            ValueError: if labels must be flipped but ``num_classes < 2``.
        """
        y_flipped = y.clone()
        n_flip = int(len(y) * self.rate)
        if n_flip == 0:
            return y_flipped
        if num_classes < 2:
            raise ValueError(f"label flipping needs at least 2 classes, got {num_classes}")
        
        flip_indices = torch.randperm(len(y))[:n_flip].to(y.device)
        # A shift in 1..num_classes-1 (mod num_classes) is uniform over the
        # other classes and never maps a valid label onto itself.
        offsets = torch.randint(1, num_classes, (n_flip,), device=y.device, dtype=y.dtype)
        y_flipped[flip_indices] = (y[flip_indices] + offsets) % num_classes
        return y_flipped
    
    def backdoor_attack(self, x: torch.Tensor, target_class: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Backdoor trigger attack: blend a fixed trigger (1.0 on ~5% of the
        feature columns, at least one) into a ``rate`` fraction of samples as
        ``0.95 * x + 0.05 * trigger``.
        Reference: Equation 38

        Works for per-sample inputs of any rank (features on the last axis).
        ``target_class`` is currently unused; relabelling the returned indices
        is left to the caller.

        Returns:
            (poisoned copy of ``x``, indices of poisoned samples)
        """
        x_backdoor = x.clone()
        n_poison = int(len(x) * self.rate)
        poison_indices = torch.randperm(len(x))[:n_poison]
        
        n_features = x.shape[-1]
        # At least one feature, otherwise small inputs get an all-zero trigger.
        n_trigger_features = max(1, int(0.05 * n_features))
        trigger = torch.zeros_like(x[0])
        trigger_positions = torch.randperm(n_features)[:n_trigger_features].to(x.device)
        trigger[..., trigger_positions] = 1.0
        
        rows = poison_indices.to(x.device)
        x_backdoor[rows] = x_backdoor[rows] * 0.95 + trigger * 0.05
            
        return x_backdoor, poison_indices
    
    def clean_label_poisoning(self, model: nn.Module, x: torch.Tensor, 
                             y: torch.Tensor) -> torch.Tensor:
        """
        Clean-label poisoning: keep the labels and run ``poison_iterations``
        signed-gradient *ascent* steps (step ``0.1 * epsilon``) on the
        cross-entropy, projecting back into the L-inf ball of radius
        ``epsilon`` around each original sample.
        Reference: Equation 39

        The previous version stepped along sign(grad(-CE)), which decreased the
        cross-entropy, the opposite of the stated "maximize confusion" intent.
        """
        x_clean = x.clone()
        n_poison = int(len(x) * self.rate)
        poison_indices = torch.randperm(len(x))[:n_poison]
        step = 0.1 * self.epsilon
        
        for idx in poison_indices.tolist():
            x_orig = x[idx:idx + 1].detach()
            y_curr = y[idx:idx + 1]
            x_curr = x_orig.clone()
            
            for _ in range(self.config.poison_iterations):
                x_curr.requires_grad_(True)
                loss = F.cross_entropy(model(x_curr), y_curr)
                (grad,) = torch.autograd.grad(loss, x_curr)
                with torch.no_grad():
                    x_curr = x_curr + step * grad.sign()
                    x_curr = torch.clamp(x_curr, x_orig - self.epsilon,
                                         x_orig + self.epsilon)
                    
            x_clean[idx] = x_curr.detach()[0]
            
        return x_clean




# ============================================================================
# COMPREHENSIVE EVALUATION FRAMEWORK (23 METRICS)
# ============================================================================

# In [None]:


class ComprehensiveEvaluator:
    """
    23-metric evaluation framework
    Reference: Section V.E, Table V

    ``evaluate`` runs a model over a dataloader and fills ``self.metrics``
    with the following groups of metrics:

    - accuracy and classification metrics
    - ROC/AUC
    - error rates
    - calibration
    - uncertainty
    - temporal consistency
    - efficiency
    - model complexity
    - feature quality
    - an optional poisoning-resilience score

    ``print_metrics`` pretty-prints the most recent results.
    """
    
    def __init__(self, config: MambaShieldConfig):
        self.config = config
        self.metrics = {}

    @staticmethod
    def _extract_logits(outputs):
        """Return the logits from a model output that is a tensor or a dict with 'logits'."""
        return outputs['logits'] if isinstance(outputs, dict) else outputs
        
    def evaluate(self, model: nn.Module, dataloader: DataLoader, 
                device: torch.device, attack_simulator: Optional[PoisoningAttackSimulator] = None) -> Dict:
        """
        Comprehensive evaluation with all 23 metrics

        Args:
            model: Called as ``model(data, return_features=True)``. It may return
                a logits tensor, or a dict with ``'logits'`` and optionally
                ``'features'``.
            dataloader: Yields ``(data, labels)`` batches.
            device: Device the batches are moved to.
            attack_simulator: When given, poisoning resilience is also measured.

        Returns:
            The metrics dict, also stored in ``self.metrics``. Metrics from
            any previous call are discarded.

        Raises:
            ValueError: If ``dataloader`` yields no batches.
        """
        # Reset so metrics from an earlier call (e.g. a resilience score when
        # no simulator is passed this time) cannot leak into this result.
        self.metrics = {}
        model.eval()
        
        all_preds = []
        all_labels = []
        all_probs = []
        all_features = []
        inference_times = []
        # CUDA kernels launch asynchronously; without synchronisation the timer
        # would only measure launch overhead rather than the forward pass.
        sync_cuda = torch.device(device).type == 'cuda' and torch.cuda.is_available()
        
        with torch.no_grad():
            for data, labels in dataloader:
                data, labels = data.to(device), labels.to(device)
                
                # Measure inference time
                if sync_cuda:
                    torch.cuda.synchronize(device)
                start_time = time.perf_counter()
                outputs = model(data, return_features=True)
                if sync_cuda:
                    torch.cuda.synchronize(device)
                inference_times.append(time.perf_counter() - start_time)
                
                logits = self._extract_logits(outputs)
                if isinstance(outputs, dict) and 'features' in outputs:
                    all_features.append(outputs['features'].cpu())
                
                all_preds.append(logits.argmax(dim=-1).cpu().numpy())
                all_labels.append(labels.cpu().numpy())
                all_probs.append(F.softmax(logits, dim=-1).cpu().numpy())

        if not all_labels:
            raise ValueError("ComprehensiveEvaluator.evaluate: dataloader yielded no batches")
        
        all_preds = np.concatenate(all_preds)
        all_labels = np.concatenate(all_labels)
        all_probs = np.concatenate(all_probs)
        
        # 1. Accuracy Metrics
        self.metrics['accuracy'] = accuracy_score(all_labels, all_preds)
        self.metrics['balanced_accuracy'] = balanced_accuracy_score(all_labels, all_preds)
        self.metrics['top_5_accuracy'] = self._top_k_accuracy(all_labels, all_probs, k=5)
        
        # 2. Robustness Metrics
        if attack_simulator:
            self.metrics['poisoning_resilience'] = self._evaluate_poisoning_resilience(
                model, dataloader, device, attack_simulator
            )
        
        # 3. Precision, Recall, F1
        self.metrics['precision'] = precision_score(all_labels, all_preds, average='weighted', zero_division=0)
        self.metrics['recall'] = recall_score(all_labels, all_preds, average='weighted', zero_division=0)
        self.metrics['f1_score'] = f1_score(all_labels, all_preds, average='weighted', zero_division=0)
        
        # 4. ROC-AUC
        self.metrics['roc_auc'] = self._compute_roc_auc(all_labels, all_probs)
        
        # 5. Matthews Correlation Coefficient
        self.metrics['mcc'] = matthews_corrcoef(all_labels, all_preds)
        
        # 6. Cohen's Kappa
        self.metrics['cohen_kappa'] = cohen_kappa_score(all_labels, all_preds)
        
        # 7. Confusion Matrix Metrics
        cm = confusion_matrix(all_labels, all_preds)
        self._compute_confusion_metrics(cm)
        
        # 8. Uncertainty Metrics
        self._compute_uncertainty_metrics(all_probs)
        
        # 9. Calibration Metrics
        self.metrics['brier_score'] = self._compute_brier_score(all_labels, all_probs)
        self.metrics['ece'] = self._compute_ece(all_labels, all_preds, all_probs)
        
        # 10. Temporal Metrics
        self.metrics['temporal_consistency'] = self._compute_temporal_consistency(all_preds)
        
        # 11. Efficiency Metrics
        self.metrics['avg_inference_time_ms'] = np.mean(inference_times) * 1000
        self.metrics['memory_usage_mb'] = self._get_memory_usage()
        self.metrics['flops'] = self._estimate_flops(model)
        
        # 12. Model Complexity
        self.metrics['num_parameters'] = sum(p.numel() for p in model.parameters())
        self.metrics['num_trainable_params'] = sum(p.numel() for p in model.parameters() if p.requires_grad)
        
        # 13. Feature Quality Metrics
        if all_features:
            features = torch.cat(all_features, dim=0).numpy()
            self.metrics['feature_variance'] = np.mean(np.var(features, axis=0))
            self.metrics['feature_sparsity'] = np.mean(features == 0)
        
        return self.metrics
    
    def _top_k_accuracy(self, labels: np.ndarray, probs: np.ndarray, k: int = 5) -> float:
        """Fraction of samples whose label is among the k highest-probability classes (k capped at n_classes)."""
        k = min(k, probs.shape[1])
        top_k_preds = np.argsort(probs, axis=1)[:, -k:]
        return float(np.mean(np.any(top_k_preds == labels[:, None], axis=1)))
    
    def _evaluate_poisoning_resilience(self, model: nn.Module, dataloader: DataLoader,
                                      device: torch.device, attack_simulator: PoisoningAttackSimulator) -> float:
        """
        Evaluate model resilience to poisoning attacks

        Returns:
            (correct predictions on attacked inputs) / (correct predictions on
            clean inputs), over batches until at least 1000 samples were seen.
        """
        correct_before = 0
        correct_after = 0
        total = 0
        
        model.eval()
        for data, labels in dataloader:
            data, labels = data.to(device), labels.to(device)
            
            # Clean accuracy
            with torch.no_grad():
                preds_clean = self._extract_logits(model(data)).argmax(dim=-1)
            correct_before += (preds_clean == labels).sum().item()
            
            # A gradient-based attack presumably needs autograd; running it
            # under no_grad would disable that, so enable grad explicitly.
            with torch.enable_grad():
                data_poisoned = attack_simulator.gradient_based_poisoning(model, data, labels)
            
            # Poisoned accuracy
            with torch.no_grad():
                preds_poisoned = self._extract_logits(model(data_poisoned.detach())).argmax(dim=-1)
            correct_after += (preds_poisoned == labels).sum().item()
            
            total += labels.size(0)
            
            if total >= 1000:  # Sample for efficiency
                break
        
        return correct_after / max(correct_before, 1)

    def _compute_roc_auc(self, labels: np.ndarray, probs: np.ndarray) -> float:
        """
        Macro one-vs-rest ROC-AUC over the classes present in ``labels``.

        For every output class ``c`` that has both positive and negative
        samples, the binary AUC of ``labels == c`` against ``probs[:, c]`` is
        computed, and the results are averaged.

        - When all classes are present this equals sklearn's
          ``multi_class='ovr'`` macro score.
        - For 2-column softmax outputs it equals the usual binary AUC.
        - Classes absent from ``labels`` are skipped rather than scoring the
          wrong probability column.

        Returns 0.0 when no class has a defined AUC.
        """
        per_class_auc = []
        for c in range(probs.shape[1]):
            is_positive = labels == c
            if is_positive.any() and not is_positive.all():
                per_class_auc.append(roc_auc_score(is_positive, probs[:, c]))
        return float(np.mean(per_class_auc)) if per_class_auc else 0.0
    
    def _compute_confusion_metrics(self, cm: np.ndarray):
        """Store mean per-class TPR, FPR and FNR derived from the confusion matrix ``cm``."""
        # True Positive Rate (Sensitivity) per class
        tpr_per_class = np.diag(cm) / (cm.sum(axis=1) + 1e-10)
        self.metrics['mean_tpr'] = np.mean(tpr_per_class)
        
        # False Positive Rate
        fp = cm.sum(axis=0) - np.diag(cm)
        fn = cm.sum(axis=1) - np.diag(cm)
        tp = np.diag(cm)
        tn = cm.sum() - (fp + fn + tp)
        
        fpr_per_class = fp / (fp + tn + 1e-10)
        self.metrics['mean_fpr'] = np.mean(fpr_per_class)
        
        # False Negative Rate
        fnr_per_class = fn / (fn + tp + 1e-10)
        self.metrics['mean_fnr'] = np.mean(fnr_per_class)
    
    def _compute_uncertainty_metrics(self, probs: np.ndarray):
        """Store predictive-entropy statistics and simplified aleatoric/epistemic proxies."""
        # Predictive entropy
        entropy = -np.sum(probs * np.log(probs + 1e-10), axis=1)
        self.metrics['mean_entropy'] = np.mean(entropy)
        self.metrics['std_entropy'] = np.std(entropy)
        
        # Aleatoric and epistemic uncertainty (simplified)
        self.metrics['aleatoric_uncertainty'] = np.mean(entropy)
        self.metrics['epistemic_uncertainty'] = np.std(probs.max(axis=1))
    
    def _compute_brier_score(self, labels: np.ndarray, probs: np.ndarray) -> float:
        """Multi-class Brier score: mean over samples of the squared L2 distance to the one-hot label."""
        onehot = np.eye(probs.shape[1])[labels]
        return float(np.mean(np.sum((probs - onehot) ** 2, axis=1)))
    
    def _compute_ece(self, labels: np.ndarray, preds: np.ndarray, 
                    probs: np.ndarray, n_bins: int = 10) -> float:
        """Expected Calibration Error with ``n_bins`` equal-width confidence bins over (0, 1]."""
        max_probs = np.max(probs, axis=1)
        correct = (preds == labels).astype(float)
        
        ece = 0
        for bin_i in range(n_bins):
            bin_lower = bin_i / n_bins
            bin_upper = (bin_i + 1) / n_bins
            
            in_bin = (max_probs > bin_lower) & (max_probs <= bin_upper)
            bin_size = np.sum(in_bin)
            if bin_size > 0:
                bin_acc = np.mean(correct[in_bin])
                bin_conf = np.mean(max_probs[in_bin])
                ece += (bin_size / len(labels)) * abs(bin_acc - bin_conf)
        
        return ece
    
    def _compute_temporal_consistency(self, preds: np.ndarray) -> float:
        """Fraction of consecutive prediction pairs that agree (1.0 for fewer than 2 predictions)."""
        if len(preds) < 2:
            return 1.0
        return float(np.mean(preds[1:] == preds[:-1]))
    
    def _get_memory_usage(self) -> float:
        """Currently allocated CUDA memory in MB (1e6 bytes); 0 when CUDA is unavailable."""
        if torch.cuda.is_available():
            return torch.cuda.memory_allocated() / 1e6
        return 0
    
    def _estimate_flops(self, model: nn.Module) -> int:
        """
        Rough multiply-accumulate count of Linear and Conv1d layers.

        Counts weights per output position: ``in*out`` for Linear and
        ``in/groups * out * kernel`` for Conv1d. Sequence length, bias terms
        and other layers are ignored.
        """
        total_flops = 0
        for module in model.modules():
            if isinstance(module, nn.Linear):
                total_flops += module.in_features * module.out_features
            elif isinstance(module, nn.Conv1d):
                # Grouped/depthwise convs connect only in_channels/groups inputs
                # to each output channel.
                total_flops += ((module.in_channels // module.groups) * module.out_channels *
                                module.kernel_size[0])
        return total_flops
    
    def print_metrics(self):
        """Pretty print all metrics"""
        print("\n" + "="*60)
        print("COMPREHENSIVE EVALUATION RESULTS (23 Metrics)")
        print("="*60)
        
        categories = {
            'Accuracy Metrics': ['accuracy', 'balanced_accuracy', 'top_5_accuracy'],
            'Classification Metrics': ['precision', 'recall', 'f1_score', 'mcc', 'cohen_kappa'],
            'ROC/AUC Metrics': ['roc_auc'],
            'Error Rates': ['mean_fpr', 'mean_fnr', 'mean_tpr'],
            'Calibration Metrics': ['brier_score', 'ece'],
            'Uncertainty Metrics': ['mean_entropy', 'std_entropy', 'aleatoric_uncertainty', 'epistemic_uncertainty'],
            'Temporal Metrics': ['temporal_consistency'],
            'Efficiency Metrics': ['avg_inference_time_ms', 'memory_usage_mb', 'flops'],
            'Model Complexity': ['num_parameters', 'num_trainable_params'],
            'Feature Metrics': ['feature_variance', 'feature_sparsity'],
            'Robustness': ['poisoning_resilience']
        }
        
        for category, metric_names in categories.items():
            print(f"\n{category}:")
            print("-" * 40)
            for metric in metric_names:
                if metric in self.metrics:
                    value = self.metrics[metric]
                    # np.float32 is not a subclass of float, so check np.floating too.
                    if isinstance(value, (float, np.floating)):
                        print(f"  {metric:30s}: {value:.4f}")
                    else:
                        print(f"  {metric:30s}: {value}")




# ============================================================================
# UNIFIED TAXONOMY AND FEATURE PROCESSING
# ============================================================================

In [None]:


class UnifiedTaxonomy:
    """
    Unified attack taxonomy across datasets
    Reference: Section V.A.1

    Maps dataset-specific label strings onto a small set of unified
    categories; unknown labels map to 'Other'. Matching works in two stages:

    1. Case-insensitive exact match (after stripping surrounding whitespace).
    2. Fallback that treats '-', '_', en dashes and runs of whitespace as the
       same separator. 'DoS Hulk', 'DoS_Hulk' and 'dos-hulk' therefore all
       resolve to the same category.
    """
    
    def __init__(self):
        self.taxonomy = {
            'Normal': ['Normal/Benign', 'BENIGN', 'Benign', 'Normal', '0', 'BenignTraffic'],
            'DoS/DDoS': [
                'DoS', 'DDoS', 'DDOS-SLOWLORIS', 'DDOS-SYNONYMOUSIP_FLOOD',
                'DDOS-ICMP_FLOOD', 'DDOS-RSTFINFLOOD', 'DDOS-PSHACK_FLOOD',
                'DDOS-SYN_FLOOD', 'DDOS-TCP_FLOOD', 'DDOS-UDP_FLOOD',
                'DOS-UDP_FLOOD', 'DOS-SYN_FLOOD', 'DOS-TCP_FLOOD',
                'DoS_Hulk', 'DoS_GoldenEye', 'DoS_Slowloris', 'DoS_Slowhttptest'
            ],
            'Reconnaissance': [
                'Scanning', 'RECON-PORTSCAN', 'RECON-OSSCAN', 'RECON-HOSTDISCOVERY',
                'RECON-PINGSWEEP', 'VULNERABILITYSCAN', 'Heartbleed', 'PortScan',
                'Reconnaissance', 'Analysis'
            ],
            'Malware': [
                'BACKDOOR_MALWARE', 'Rootkit', 'Trojan', 'Worm', 'Worms', 'Botnet',
                'Malware', 'Bot', 'Mirai', 'Generic', 'Shellcode'
            ],
            'Injection': [
                'SQL_Injection', 'SQLINJECTION', 'COMMANDINJECTION', 'XSS',
                'Web Attack', 'Web-based'
            ],
            'BruteForce': [
                'DICTIONARYBRUTEFORCE', 'Brute_Force', 'FTP_Patator',
                'Password_Attack', 'Brute Force', 'SSH-Patator'
            ],
            'Exploitation': [
                # 'Backdoors' / 'Infilteration' are spelling variants used by
                # some public dataset releases (UNSW-NB15, CSE-CIC-IDS2018).
                'Infiltration', 'Infilteration', 'Backdoor', 'Backdoors',
                'Exploits', 'Fuzzers'
            ]
        }
        
        # Exact lookup: lower-cased label -> category.
        self.reverse_map = {}
        # Separator-insensitive fallback lookup (see _normalize).
        self._normalized_map = {}
        for category, attacks in self.taxonomy.items():
            for attack in attacks:
                self.reverse_map[attack.lower()] = category
                self._normalized_map[self._normalize(attack)] = category

    @staticmethod
    def _normalize(label) -> str:
        """Lower-case ``label`` and collapse '-', '_', en dashes and whitespace runs into '_'."""
        text = str(label).lower()
        for separator in ('-', '_', '\u2013'):
            text = text.replace(separator, ' ')
        return '_'.join(text.split())
    
    def map_label(self, label: str) -> str:
        """
        Map a dataset-specific attack label to its unified category.

        Non-string labels (e.g. the integer 0) are converted with ``str``
        first. Returns 'Other' when no match is found.
        """
        label_lower = str(label).lower().strip()
        if label_lower in self.reverse_map:
            return self.reverse_map[label_lower]
        return self._normalized_map.get(self._normalize(label_lower), 'Other')


class UnifiedFeatureProcessor:
    """
    Process and align features to 84 unified dimensions
    Reference: Section V.A.1

    ``process_features`` projects a DataFrame onto the fixed list of unified
    feature names. Missing features are filled with zeros.
    ``self.scaler`` is a StandardScaler that callers fit and apply
    themselves; ``process_features`` does not use it.
    """
    
    def __init__(self, target_features: int = 84):
        self.target_features = target_features
        self.scaler = StandardScaler()
        self.feature_names = self._define_unified_features()
    
    def _define_unified_features(self) -> List[str]:
        """Define 84 unified network flow features"""
        return [
            # Packet statistics (1-10)
            'flow_duration', 'total_fwd_packets', 'total_bwd_packets',
            'total_length_fwd_packets', 'total_length_bwd_packets',
            'fwd_packet_length_max', 'fwd_packet_length_min', 
            'fwd_packet_length_mean', 'fwd_packet_length_std', 'bwd_packet_length_max',
            
            # Flow statistics (11-20)
            'bwd_packet_length_min', 'bwd_packet_length_mean', 'bwd_packet_length_std',
            'flow_bytes_per_sec', 'flow_packets_per_sec', 
            'flow_iat_mean', 'flow_iat_std', 'flow_iat_max', 'flow_iat_min',
            'fwd_iat_total',
            
            # Inter-arrival times (21-30)
            'fwd_iat_mean', 'fwd_iat_std', 'fwd_iat_max', 'fwd_iat_min',
            'bwd_iat_total', 'bwd_iat_mean', 'bwd_iat_std', 
            'bwd_iat_max', 'bwd_iat_min', 'fwd_psh_flags',
            
            # TCP flags (31-42)
            'bwd_psh_flags', 'fwd_urg_flags', 'bwd_urg_flags',
            'fin_flag_count', 'syn_flag_count', 'rst_flag_count', 
            'psh_flag_count', 'ack_flag_count', 'urg_flag_count', 
            'cwe_flag_count', 'ece_flag_count', 'fwd_header_length',
            
            # Header and packet info (43-60)
            'bwd_header_length', 'fwd_packets_per_sec', 'bwd_packets_per_sec',
            'min_packet_length', 'max_packet_length', 'packet_length_mean',
            'packet_length_std', 'packet_length_variance', 'down_up_ratio',
            'average_packet_size', 'fwd_segment_size_avg', 'bwd_segment_size_avg',
            'fwd_bulk_rate_avg', 'bwd_bulk_rate_avg', 'subflow_fwd_packets',
            'subflow_fwd_bytes', 'subflow_bwd_packets', 'subflow_bwd_bytes',
            
            # Window and segment features (61-75)
            'init_win_bytes_forward', 'init_win_bytes_backward',
            'act_data_pkt_fwd', 'min_seg_size_forward',
            'active_mean', 'active_std', 'active_max', 'active_min',
            'idle_mean', 'idle_std', 'idle_max', 'idle_min',
            'protocol', 'src_port', 'dst_port',
            
            # Additional features (76-84)
            'fwd_bytes_bulk_avg', 'fwd_packet_bulk_avg', 'fwd_bulk_size_avg',
            'bwd_bytes_bulk_avg', 'bwd_packet_bulk_avg', 'bwd_bulk_size_avg',
            'fwd_subflow_packets', 'fwd_subflow_bytes', 'bwd_subflow_bytes'
        ]

    @staticmethod
    def _normalize_column_name(name) -> str:
        """Case/whitespace-insensitive column key, e.g. ' Fwd IAT Mean' -> 'fwd_iat_mean'."""
        return '_'.join(str(name).strip().lower().split())
    
    def process_features(self, df: pd.DataFrame) -> np.ndarray:
        """
        Process and align features to unified format.

        For each unified feature name, the matching column is looked up in
        this order:

        1. the exact name;
        2. its Title-Cased, space-separated form;
        3. a case- and surrounding-whitespace-insensitive match. This catches
           names such as ' Fwd IAT Mean' or 'FIN Flag Count', which
           ``.title()`` cannot reproduce because of the acronyms.

        Values are coerced to numbers. Unparseable entries and NaN become 0,
        and +/-inf become +/-1e10.

        Returns:
            A float array of shape ``(len(df), self.target_features)``.
            Features that cannot be found stay 0.
        """
        processed = np.zeros((len(df), self.target_features))

        # Normalised-name index of the frame's columns; on collisions the
        # first column wins.
        normalized_columns = {}
        for column in df.columns:
            normalized_columns.setdefault(self._normalize_column_name(column), column)
        
        for i, feature in enumerate(self.feature_names[:self.target_features]):
            title_name = feature.replace('_', ' ').title()
            if feature in df.columns:
                col = feature
            elif title_name in df.columns:
                col = title_name
            else:
                col = normalized_columns.get(feature)
            if col is not None:
                processed[:, i] = pd.to_numeric(df[col], errors='coerce').fillna(0).values
        
        # Handle infinite values
        processed = np.nan_to_num(processed, nan=0.0, posinf=1e10, neginf=-1e10)
        
        # Clip extreme values
        processed = np.clip(processed, -1e10, 1e10)
        
        return processed




# ============================================================================
# DATASET LOADER
# ============================================================================


In [None]:

class MambaShieldDataset(Dataset):
    """
    Custom dataset for MambaShield

    Turns a table of per-row features into overlapping sliding windows of
    ``seq_len`` consecutive rows. Window ``i`` covers rows
    ``i .. i + seq_len - 1`` and is labelled with the label of its last row.

    ``self.sequences`` is a read-only strided view of ``features`` with shape
    ``(n_windows, seq_len, *row_shape)``. Because it is a view rather than a
    materialised copy, memory stays proportional to ``len(features)``
    instead of ``len(features) * seq_len``.
    """
    
    def __init__(self, features: np.ndarray, labels: np.ndarray, 
                 seq_len: int = 10, transform=None):
        if seq_len < 1:
            raise ValueError(f"seq_len must be >= 1, got {seq_len}")
        self.features = features
        self.labels = labels
        self.seq_len = seq_len
        self.transform = transform
        
        feature_array = np.asarray(features)
        label_array = np.asarray(labels)
        n_windows = max(len(feature_array) - seq_len + 1, 0)
        
        if n_windows > 0:
            windows = np.lib.stride_tricks.sliding_window_view(feature_array, seq_len, axis=0)
            # sliding_window_view appends the window axis last:
            # (n_windows, *row_shape, seq_len) -> (n_windows, seq_len, *row_shape).
            self.sequences = np.moveaxis(windows, -1, 1)
        else:
            # Fewer rows than seq_len: no complete window exists.
            self.sequences = np.empty((0, seq_len) + feature_array.shape[1:],
                                      dtype=feature_array.dtype)
        # Label of each window = label of its last row.
        self.seq_labels = label_array[seq_len - 1:seq_len - 1 + n_windows]
    
    def __len__(self):
        return len(self.sequences)
    
    def __getitem__(self, idx):
        # Copy the window out of the read-only view so transforms may modify it
        # freely without corrupting the underlying feature table.
        sequence = np.array(self.sequences[idx])
        label = self.seq_labels[idx]
        
        if self.transform:
            sequence = self.transform(sequence)
        
        return torch.FloatTensor(sequence), torch.LongTensor([label])[0]




# ============================================================================
# TRAINING FUNCTIONS
# ============================================================================

In [None]:


def train_epoch(model: nn.Module, dataloader: DataLoader, optimizer: torch.optim.Optimizer,
                criterion: nn.Module, device: torch.device, config: MambaShieldConfig,
                scaler: GradScaler, attack_simulator: Optional[PoisoningAttackSimulator] = None,
                epoch: int = 0) -> Dict[str, float]:
    """
    Train for one epoch with optional poisoning attacks

    If ``attack_simulator`` is given, each batch has a 25% chance of being
    attacked. An attacked batch is split evenly between two attacks:

    - gradient-based input poisoning;
    - label flipping over the 7 unified categories.

    Each step uses mixed precision, with gradient clipping to
    ``config.gradient_clip`` applied after unscaling.

    Returns:
        ``{'loss': mean per-batch loss, 'accuracy': fraction of correct
        predictions}``, where accuracy is measured against the (possibly
        flipped) training targets. Both are 0.0 if the loader is empty.
    """
    model.train()
    
    total_loss = 0
    correct = 0
    total = 0
    n_batches = 0
    
    for data, target in dataloader:
        data, target = data.to(device), target.to(device)
        
        # Simulate poisoning attacks (25% of batches)
        if attack_simulator is not None and random.random() < 0.25:
            if random.random() < 0.5:
                # Treat poisoned inputs as fixed data: drop any autograd
                # history created by the attack.
                data = attack_simulator.gradient_based_poisoning(model, data, target).detach()
            else:
                target = attack_simulator.label_flipping(target, 7)  # 7 unified categories
        
        # Clears any parameter gradients the attack may have produced as well.
        optimizer.zero_grad()
        
        # Mixed precision training
        with autocast():
            outputs = model(data)
            loss = criterion(outputs, target)
        
        # Backward pass
        scaler.scale(loss).backward()
        
        # Gradients must be unscaled before clipping so the threshold applies
        # to their true magnitude.
        scaler.unscale_(optimizer)
        clip_grad_norm_(model.parameters(), config.gradient_clip)
        
        # Optimizer step
        scaler.step(optimizer)
        scaler.update()
        
        # Metrics
        total_loss += loss.item()
        pred = outputs.argmax(dim=1)
        correct += (pred == target).sum().item()
        total += target.size(0)
        n_batches += 1
    
    metrics = {
        'loss': total_loss / max(n_batches, 1),
        'accuracy': correct / max(total, 1)
    }
    
    return metrics


def validate(model: nn.Module, dataloader: DataLoader, criterion: nn.Module,
            device: torch.device) -> Dict[str, float]:
    """
    Validate model.

    Runs ``model`` in eval mode over ``dataloader`` without tracking gradients.

    Returns:
        ``{'loss': mean of per-batch criterion values,
        'accuracy': fraction of correctly classified samples}``.
        Both are 0.0 when the dataloader yields no batches.
    """
    model.eval()
    
    total_loss = 0.0
    correct = 0
    total = 0
    n_batches = 0
    
    with torch.no_grad():
        for data, target in dataloader:
            data, target = data.to(device), target.to(device)
            
            outputs = model(data)
            loss = criterion(outputs, target)
            
            total_loss += loss.item()
            pred = outputs.argmax(dim=1)
            correct += (pred == target).sum().item()
            total += target.size(0)
            n_batches += 1
    
    metrics = {
        'loss': total_loss / max(n_batches, 1),
        'accuracy': correct / max(total, 1)
    }
    
    return metrics




# ============================================================================
# MAIN TRAINING PIPELINE
# ============================================================================

In [None]:

def train_mambashield(config: MambaShieldConfig, 
                     train_data: Tuple[np.ndarray, np.ndarray],
                     val_data: Tuple[np.ndarray, np.ndarray],
                     test_data: Tuple[np.ndarray, np.ndarray]) -> Dict:
    """
    Complete training pipeline for MambaShield
    Reference: Algorithm 2 and Section V

    Pipeline:

    1. Build 10-step sliding-window datasets.
    2. Train ``MambaShieldModel`` for ``config.epochs`` epochs, using AdamW,
       cosine LR annealing, mixed precision and simulated poisoning.
    3. Restore the weights with the best validation accuracy.
    4. Run the comprehensive test-set evaluation and the PAC-Bayes bound.

    Args:
        config: Hyper-parameters (batch_size, learning_rate, weight_decay,
            epochs, ...).
        train_data: ``(features, labels)`` arrays.
        val_data: ``(features, labels)`` arrays.
        test_data: ``(features, labels)`` arrays.

    Returns:
        Dict with keys ``'model'``, ``'history'``, ``'final_metrics'`` and
        ``'pac_bounds'``.
    """
    
    print("="*60)
    print("MAMBASHIELD TRAINING PIPELINE")
    print("Citation: Anaedevha et al., IEEE TAI 2025")
    print("="*60)
    
    # Device setup
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    # Pinned host memory only helps host-to-GPU copies; skip it on CPU.
    pin_memory = device.type == 'cuda'
    
    # Memory manager
    memory_manager = MemoryManager(device)
    memory_manager.clear_cache()
    
    # Create datasets
    train_dataset = MambaShieldDataset(train_data[0], train_data[1], seq_len=10)
    val_dataset = MambaShieldDataset(val_data[0], val_data[1], seq_len=10)
    test_dataset = MambaShieldDataset(test_data[0], test_data[1], seq_len=10)
    
    # Dataloaders
    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, 
                            shuffle=True, num_workers=0, pin_memory=pin_memory)
    val_loader = DataLoader(val_dataset, batch_size=config.batch_size, 
                          shuffle=False, num_workers=0, pin_memory=pin_memory)
    test_loader = DataLoader(test_dataset, batch_size=config.batch_size, 
                           shuffle=False, num_workers=0, pin_memory=pin_memory)
    
    # Initialize model
    model = MambaShieldModel(config).to(device)
    print(f"Model parameters: {sum(p.numel() for p in model.parameters())/1e6:.2f}M")
    
    # Optimizer and scheduler
    optimizer = optim.AdamW(model.parameters(), lr=config.learning_rate, 
                           weight_decay=config.weight_decay)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config.epochs)
    criterion = nn.CrossEntropyLoss()
    scaler = GradScaler()
    
    # Initialize components
    attack_simulator = PoisoningAttackSimulator(config)
    evaluator = ComprehensiveEvaluator(config)
    pac_bayes = PACBayesRegularizer(model, config)
    
    # Progressive ARD setup
    pard = ProgressiveARD(model, config)
    
    # Training history
    history = defaultdict(list)
    best_val_acc = 0
    best_model_state = None
    
    # Training loop
    print("\nStarting training...")
    for epoch in range(config.epochs):
        print(f"\nEpoch {epoch+1}/{config.epochs}")
        print("-" * 40)
        
        # Train
        train_metrics = train_epoch(
            model, train_loader, optimizer, criterion, 
            device, config, scaler, attack_simulator, epoch
        )
        
        # Validate
        val_metrics = validate(model, val_loader, criterion, device)
        
        # Update learning rate
        scheduler.step()
        
        # Save history
        for key, value in train_metrics.items():
            history[f'train_{key}'].append(value)
        for key, value in val_metrics.items():
            history[f'val_{key}'].append(value)
        
        # Print metrics
        print(f"Train Loss: {train_metrics['loss']:.4f}, Train Acc: {train_metrics['accuracy']:.4f}")
        print(f"Val Loss: {val_metrics['loss']:.4f}, Val Acc: {val_metrics['accuracy']:.4f}")
        
        # Save best model
        if val_metrics['accuracy'] > best_val_acc:
            best_val_acc = val_metrics['accuracy']
            # state_dict() returns references to the live parameter/buffer
            # tensors, which later optimizer steps update in place. Snapshot
            # detached copies, kept on the CPU so GPU memory is not doubled.
            best_model_state = {
                name: tensor.detach().to('cpu', copy=True)
                for name, tensor in model.state_dict().items()
            }
            print(f"New best model! Val Acc: {best_val_acc:.4f}")
        
        # Memory management
        if epoch % 5 == 0:
            memory_manager.clear_cache()
            mem_info = memory_manager.get_memory_info()
            if mem_info:
                print(f"GPU Memory: {mem_info['allocated_gb']:.2f}GB / {mem_info['free_gb']:.2f}GB free")
    
    # Load best model (load_state_dict copies the CPU snapshot onto the model's device)
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    
    # Final evaluation
    print("\n" + "="*60)
    print("FINAL EVALUATION ON TEST SET")
    print("="*60)
    
    final_metrics = evaluator.evaluate(model, test_loader, device, attack_simulator)
    evaluator.print_metrics()
    
    # PAC-Bayes bounds (explicit eval mode: no dropout / batch-stat updates)
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            outputs = model(data)
            test_loss += criterion(outputs, target).item()
    test_loss /= max(len(test_loader), 1)
    
    pac_bounds = pac_bayes.pac_bayes_bound(torch.tensor(test_loss), len(test_dataset))
    
    print("\nPAC-Bayes Bounds (Theorem 3):")
    for key, value in pac_bounds.items():
        print(f"  {key}: {value:.4f}")
    
    # Plot results
    plot_training_history(history)
    
    return {
        'model': model,
        'history': history,
        'final_metrics': final_metrics,
        'pac_bounds': pac_bounds
    }


def plot_training_history(history: Dict):
    """
    Show side-by-side loss and accuracy curves for training and validation.

    ``history`` is expected to hold per-epoch lists under the keys
    'train_loss', 'val_loss', 'train_accuracy' and 'val_accuracy'.
    """
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))

    # (axis, history key suffix, y-axis label, panel title)
    panels = (
        (axes[0], 'loss', 'Loss', 'Training Loss'),
        (axes[1], 'accuracy', 'Accuracy', 'Training Accuracy'),
    )
    for ax, metric, y_label, title in panels:
        ax.plot(history[f'train_{metric}'], label='Train')
        ax.plot(history[f'val_{metric}'], label='Validation')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(y_label)
        ax.set_title(title)
        ax.legend()
        ax.grid(True)

    plt.tight_layout()
    plt.show()



# ============================================================================
# EXAMPLE USAGE
# ============================================================================

In [None]:
if __name__ == "__main__":
    banner = "=" * 60
    print("MambaShield: Complete Implementation")
    print("Paper: IEEE TAI 2025")
    print(banner)

    # Configuration, then seed every RNG the pipeline draws from.
    config = MambaShieldConfig()
    seed = config.random_seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    print("\n1. Loading datasets...")

    # Synthetic stand-in data (84 unified features, 7 unified categories).
    # Swap in the real CIC-IoT-2023 / CSE-CICIDS2018 / UNSW-NB15 loaders in practice.
    n_samples = 10000
    X = np.random.randn(n_samples, 84)
    y = np.random.randint(0, 7, n_samples)

    # Carve out the test split first, then split the remainder into train/val.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=config.test_size, random_state=seed
    )
    X_train, X_val, y_train, y_val = train_test_split(
        X_train, y_train, test_size=config.val_size, random_state=seed
    )
    print(f"Train: {X_train.shape}, Val: {X_val.shape}, Test: {X_test.shape}")

    # Standardise using statistics fitted on the training split only.
    feature_processor = UnifiedFeatureProcessor()
    fitted_scaler = feature_processor.scaler
    X_train = fitted_scaler.fit_transform(X_train)
    X_val = fitted_scaler.transform(X_val)
    X_test = fitted_scaler.transform(X_test)

    print("\n2. Training MambaShield...")
    results = train_mambashield(
        config,
        train_data=(X_train, y_train),
        val_data=(X_val, y_val),
        test_data=(X_test, y_test),
    )

    summary = results['final_metrics']
    print("\n" + banner)
    print("TRAINING COMPLETE!")
    print(f"Final Test Accuracy: {summary['accuracy']:.4f}")
    print(f"Poisoning Resilience: {summary.get('poisoning_resilience', 0):.4f}")
    print(banner)
