#!/usr/bin/env python
## coding: utf-8

## MambaShield: Complete Implementation for IEEE TNNLS
## Temporal-Aware Poisoning-Resilient NIDS with Selective State Space Models
# 
## This notebook implements the cutting-edge MambaShield architecture optimized for Kaggle P100 GPU.

## Environment Setup and Imports

## Import Libraries

# In [None]:
import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import DataLoader, TensorDataset, Dataset
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import gc
import traceback
import random
import time
import warnings
from typing import Dict, List, Tuple, Optional, Union, Any
from collections import deque, defaultdict
import math
from tqdm.notebook import tqdm
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import StandardScaler, LabelEncoder, RobustScaler
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    confusion_matrix, roc_auc_score, roc_curve, auc,
    average_precision_score, matthews_corrcoef, cohen_kappa_score,
    classification_report, balanced_accuracy_score
)
import json
from datetime import datetime

# Suppress every warning category for the rest of the run.
warnings.filterwarnings('ignore')
# Select the matplotlib style sheet by name.
plt.style.use('seaborn-v0_8-darkgrid')

# Report the runtime environment once at start-up.
print("PyTorch Version:", torch.__version__)
print("CUDA Available:", torch.cuda.is_available())
if torch.cuda.is_available():
    # Device index 0 is queried; total_memory is divided by 1e9 before being printed with a "GB" suffix.
    print("GPU:", torch.cuda.get_device_name(0))
    print("GPU Memory:", torch.cuda.get_device_properties(0).total_memory / 1e9, "GB")



## Advanced Memory Management for P100


# In [None]:
class P100MemoryManager:
    """Device selection plus GPU memory helpers (written with a P100 in mind)."""

    def __init__(self):
        # CUDA when available, otherwise CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Stored for callers; no method of this class reads it.
        self.memory_threshold = 0.8  # 80% memory usage threshold

    def get_memory_usage(self):
        """Return GPU memory statistics for device 0, or ``None`` without CUDA.

        Returns:
            dict with ``allocated``, ``cached`` and ``total`` (bytes / 1e9) and
            ``usage_percent`` (allocated / total * 100), or ``None`` when CUDA
            is unavailable.
        """
        if not torch.cuda.is_available():
            return None

        allocated = torch.cuda.memory_allocated() / 1e9
        cached = torch.cuda.memory_reserved() / 1e9
        total = torch.cuda.get_device_properties(0).total_memory / 1e9
        return {
            'allocated': allocated,
            'cached': cached,
            'total': total,
            'usage_percent': (allocated / total) * 100
        }

    def optimize_batch_size(self, model, input_shape, initial_batch=64):
        """Probe for the largest batch size (halving from ``initial_batch``) that fits.

        Each candidate runs a real forward pass on random input and, when the
        output is differentiable, a backward pass through the model.

        Args:
            model: callable mapping a ``(batch, *input_shape)`` tensor to a
                tensor or to a dict holding the tensor under ``'logits'``.
            input_shape: per-sample shape, without the batch dimension.
            initial_batch: first batch size to try.

        Returns:
            The first batch size that succeeds. On CUDA the peak allocation
            during the probe must also stay below 70% of total device memory.
            Falls back to 1 when nothing larger qualifies.

        Raises:
            RuntimeError: any runtime failure that is not an out-of-memory
                error (e.g. a shape mismatch) is propagated instead of being
                mistaken for memory pressure.
        """
        batch_size = max(int(initial_batch), 1)
        on_cuda = self.device.type == "cuda" and torch.cuda.is_available()

        params_fn = getattr(model, "parameters", None)
        trainable = [p for p in params_fn() if p.requires_grad] if callable(params_fn) else []

        while True:
            dummy_input = output = loss = None
            try:
                if on_cuda:
                    # Measure the peak during the probe; memory read after cleanup
                    # would only reflect the resident model weights.
                    torch.cuda.reset_peak_memory_stats()

                dummy_input = torch.randn(batch_size, *input_shape, device=self.device)
                output = model(dummy_input)
                if isinstance(output, dict):
                    output = output.get('logits')

                if torch.is_tensor(output) and output.requires_grad and trainable:
                    loss = output.float().sum()
                    # autograd.grad exercises the backward pass without leaving
                    # gradients behind in the parameters' .grad fields.
                    torch.autograd.grad(loss, trainable, allow_unused=True)

                if not on_cuda:
                    # No device memory to inspect: a successful pass is the criterion.
                    return batch_size

                total = torch.cuda.get_device_properties(0).total_memory
                peak_percent = torch.cuda.max_memory_allocated() / total * 100
                if peak_percent < 70:
                    return batch_size

            except RuntimeError as e:
                if "out of memory" not in str(e):
                    raise
            finally:
                del dummy_input, output, loss
                self.clear_memory()

            if batch_size == 1:
                return 1
            batch_size //= 2

    def clear_memory(self):
        """Run the garbage collector and, with CUDA, empty the cache and synchronize."""
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

# Module-level manager instance; `device` is the torch device it selected
# (CUDA when available, otherwise CPU) and is read by code further down the file.
memory_manager = P100MemoryManager()
device = memory_manager.device




## Enhanced Mamba Architecture with Fixes


# In [None]:

class ImprovedSelectiveSSM(nn.Module):
    """Gated selective state-space block with a sequential (per-timestep) scan.

    The input is projected into two branches. One branch goes through a
    depthwise convolution that is truncated to the input length, a SiLU and
    the state-space scan; the other is a SiLU gate multiplied onto the scan
    output before the final projection.
    """

    def __init__(self, d_model, d_state=16, d_conv=4, dt_rank="auto"):
        """
        Args:
            d_model: channel width of the input and output.
            d_state: number of state dimensions per channel.
            d_conv: kernel size of the depthwise convolution.
            dt_rank: width of the low-rank dt projection; ``"auto"`` selects
                ``ceil(d_model / 16)``.
        """
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.dt_rank = math.ceil(d_model / 16) if dt_rank == "auto" else dt_rank

        # Projection into the (signal, gate) pair and the depthwise convolution.
        self.in_proj = nn.Linear(d_model, d_model * 2, bias=False)
        self.conv1d = nn.Conv1d(
            d_model,
            d_model,
            kernel_size=d_conv,
            padding=d_conv - 1,
            groups=d_model,
        )

        # Input-dependent dt / B / C, then dt is lifted back to d_model channels.
        self.x_proj = nn.Linear(d_model, self.dt_rank + d_state * 2, bias=False)
        self.dt_proj = nn.Linear(self.dt_rank, d_model, bias=True)

        # dt bias drawn uniformly from +/- dt_rank ** -0.5.
        bias_bound = self.dt_rank ** -0.5 * 1.0
        nn.init.uniform_(self.dt_proj.bias, -bias_bound, bias_bound)

        # A is stored as log(1..d_state), one identical row per channel.
        state_index = torch.arange(1, d_state + 1).repeat(d_model, 1)
        self.A_log = nn.Parameter(torch.log(state_index.float()))
        self.D = nn.Parameter(torch.ones(d_model))
        self.out_proj = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x, inference_params=None):
        """Map a (batch, seqlen, d_model) tensor to a tensor of the same shape.

        ``inference_params`` is accepted but not used.
        """
        _, seqlen, _ = x.shape

        signal, gate = self.in_proj(x).chunk(2, dim=-1)

        # Conv1d expects channels first; keep only the first `seqlen` outputs so
        # each position sees itself and earlier positions only.
        conv_out = self.conv1d(signal.transpose(1, 2))[:, :, :seqlen]
        activated = F.silu(conv_out.transpose(1, 2))

        gated = self.ssm(activated) * F.silu(gate)
        return self.out_proj(gated)

    def ssm(self, x):
        """Derive dt, B, C from ``x`` and run the scan with A = -exp(A_log)."""
        _batch, _seqlen, _dim = x.shape  # same rank check as the scan below

        projected = self.x_proj(x)
        dt, B, C = torch.split(
            projected, [self.dt_rank, self.d_state, self.d_state], dim=-1
        )

        # softplus keeps the step size positive.
        dt = F.softplus(self.dt_proj(dt))
        A = -torch.exp(self.A_log.float())

        return self.selective_scan_simple(x, dt, A, B, C)

    def selective_scan_simple(self, u, dt, A, B, C):
        """Sequential scan: state = exp(dt*A) * state + dt*B*u, y = sum(state*C) + u*D.

        Args:
            u: (batch, seqlen, dim) input.
            dt: (batch, seqlen, dim) step sizes.
            A: (dim, d_state) state matrix.
            B, C: (batch, seqlen, d_state) input / output projections.
        """
        n_batch, n_steps, n_chan = u.shape
        n_state = A.shape[1]

        out = torch.zeros_like(u)
        state = torch.zeros(n_batch, n_chan, n_state, device=u.device, dtype=u.dtype)
        A_row = A[None, :, :]

        for t in range(n_steps):
            step = dt[:, t, :, None]
            decay = torch.exp(step * A_row)
            drive = step * B[:, t, None, :]

            state = decay * state + drive * u[:, t, :, None]
            out[:, t] = (state * C[:, t, None, :]).sum(dim=-1)

        return out + u * self.D




## Progressive Adversarial Robustness Distillation





# In [None]:

class ProgressiveARD(nn.Module):
    """Progressive Adversarial Robustness Distillation.

    Wraps a student model and zero or more teacher models. In training mode
    ``forward`` also returns a temperature-scaled KL distillation loss whose
    weight follows a per-epoch schedule.
    """

    def __init__(self, student, teachers, temperature=3.0):
        """
        Args:
            student: model being trained.
            teachers: iterable of teacher modules (``None`` or empty disables
                distillation).
            temperature: softmax temperature used for both student and teachers.
        """
        super().__init__()
        self.student = student
        self.teachers = nn.ModuleList(teachers) if teachers else nn.ModuleList()
        self.temperature = temperature
        self.alpha_schedule = self.create_schedule()

    def create_schedule(self):
        """Return the per-epoch distillation weights (0.1 rising to 1.0)."""
        return [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]

    def distillation_loss(self, student_logits, teacher_logits):
        """KL(teacher || student) on temperature-softened logits, scaled by T**2."""
        student_soft = F.log_softmax(student_logits / self.temperature, dim=-1)
        teacher_soft = F.softmax(teacher_logits / self.temperature, dim=-1)
        loss = F.kl_div(student_soft, teacher_soft, reduction='batchmean')
        return loss * (self.temperature ** 2)

    @staticmethod
    def _logits_of(output):
        """Return ``output['logits']`` for dict outputs, else ``output`` itself."""
        return output['logits'] if isinstance(output, dict) else output

    def forward(self, x, epoch=0):
        """Run the student and, when training with teachers, the distillation loss.

        Args:
            x: input batch passed unchanged to the student and every teacher.
            epoch: index into the alpha schedule; clamped to the schedule range.

        Returns:
            ``(student_output, distill_loss)``. ``distill_loss`` is a zero
            tensor in eval mode or without teachers; otherwise it is
            differentiable with respect to the student's parameters.
        """
        student_output = self.student(x)

        if not self.training or len(self.teachers) == 0:
            return student_output, torch.tensor(0.0).to(x.device)

        # Clamp on both sides: a negative epoch must not index from the end.
        alpha_idx = max(0, min(epoch, len(self.alpha_schedule) - 1))
        alpha = self.alpha_schedule[alpha_idx]

        student_logits = self._logits_of(student_output)

        distill_loss = 0
        for teacher in self.teachers:
            teacher.eval()
            # Only the teacher forward is gradient-free. The loss itself must be
            # built outside no_grad so it can back-propagate into the student.
            with torch.no_grad():
                teacher_logits = self._logits_of(teacher(x))
            distill_loss = distill_loss + self.distillation_loss(student_logits, teacher_logits)

        distill_loss = alpha * distill_loss / max(len(self.teachers), 1)

        return student_output, distill_loss




## 5. PAC-Bayes Certified Robustness


# In [None]:
class PACBayesRobustness:
    """PAC-Bayes style generalization bound and noise-averaged prediction helpers."""

    def __init__(self, model, prior_std=1.0, delta=0.05):
        """
        Args:
            model: module whose trainable parameters act as the posterior mean.
            prior_std: standard deviation of the zero-mean Gaussian prior.
            delta: confidence parameter of the bound.
        """
        self.model = model
        self.prior_std = prior_std
        self.delta = delta

    def compute_kl_divergence(self):
        """KL( N(theta, I) || N(0, prior_std**2 * I) ) over trainable parameters.

        The posterior variance is fixed at 1. Per coordinate the divergence is
        ``0.5 * (1/s2 + theta**2/s2 - 1 + log(s2))`` with ``s2 = prior_std**2``.
        With ``prior_std == 1`` this reduces to ``0.5 * theta**2``.

        Returns:
            0-dim tensor (zero when the model has no trainable parameters).
        """
        prior_var = self.prior_std ** 2
        # Parameter-independent part: trace term, the -1, and the log-variance ratio.
        constant = 1.0 / prior_var - 1.0 + math.log(prior_var)

        kl = None
        for param in self.model.parameters():
            if not param.requires_grad:
                continue
            term = 0.5 * torch.sum((param ** 2) / prior_var + constant)
            kl = term if kl is None else kl + term

        if kl is None:
            # Always return a tensor so callers can use .item() / torch ops.
            return torch.zeros(())
        return kl

    def pac_bayes_bound(self, empirical_risk, n_samples):
        """Return ``empirical_risk + sqrt((KL + ln(2*sqrt(n)/delta)) / (2n))``.

        Args:
            empirical_risk: tensor or plain number.
            n_samples: number of samples the risk was measured on (> 0).

        Returns:
            dict of floats: ``empirical_risk``, ``kl_divergence``,
            ``complexity`` and ``bound``.

        Raises:
            ValueError: if ``n_samples`` is not positive.
        """
        if n_samples <= 0:
            raise ValueError("n_samples must be positive")

        kl = self.compute_kl_divergence()
        # Accept floats and align device/dtype with the KL term.
        empirical_risk = torch.as_tensor(empirical_risk, dtype=kl.dtype, device=kl.device)

        # McAllester's bound
        complexity = torch.sqrt(
            (kl + math.log(2 * math.sqrt(n_samples) / self.delta)) /
            (2 * n_samples)
        )

        bound = empirical_risk + complexity

        return {
            'empirical_risk': empirical_risk.item(),
            'kl_divergence': kl.item(),
            'complexity': complexity.item(),
            'bound': bound.item()
        }

    def randomized_smoothing(self, x, num_samples=100, sigma=0.1):
        """Average softmax predictions over Gaussian-perturbed copies of ``x``.

        Args:
            x: input batch.
            num_samples: number of noisy copies to evaluate.
            sigma: standard deviation of the added noise.

        Returns:
            ``(certified_pred, certified_radius)``. The prediction is the
            argmax of the averaged probabilities. The radius is the simplified
            value ``sigma * (p_top1 - p_top2) / 2`` computed here, not a
            formal certificate.
        """
        predictions = []

        for _ in range(num_samples):
            noisy_x = x + torch.randn_like(x) * sigma

            with torch.no_grad():
                output = self.model(noisy_x)
                logits = output['logits'] if isinstance(output, dict) else output
                predictions.append(F.softmax(logits, dim=-1))

        # Aggregate predictions
        avg_pred = torch.stack(predictions).mean(0)
        certified_pred = avg_pred.argmax(dim=-1)

        # Compute certified radius (simplified)
        top2_probs, _ = avg_pred.topk(2, dim=-1)
        gap = top2_probs[:, 0] - top2_probs[:, 1]
        certified_radius = sigma * gap / 2

        return certified_pred, certified_radius



## Comprehensive Evaluation Framework (23 Metrics)



# In [None]:

class ComprehensiveEvaluator:
    """23-metric comprehensive evaluation framework.

    ``evaluate`` fills ``self.metrics`` (a dict keyed by metric name) and
    ``print_metrics`` renders whatever the last ``evaluate`` call produced.
    """

    def __init__(self):
        # Metric name -> value, replaced on every evaluate() call.
        self.metrics = {}

    @staticmethod
    def _extract_logits(outputs):
        """Return ``outputs['logits']`` for dict outputs, else ``outputs``."""
        return outputs['logits'] if isinstance(outputs, dict) else outputs

    @staticmethod
    def _mean_rate(numerator, denominator):
        """Mean of ``numerator / denominator`` over entries with a positive denominator.

        Classes without support (zero denominator) are left out of the mean
        instead of producing NaN. Returns 0.0 when no entry is defined.
        """
        numerator = np.asarray(numerator, dtype=float)
        denominator = np.asarray(denominator, dtype=float)
        defined = denominator > 0
        if not np.any(defined):
            return 0.0
        return float(np.mean(numerator[defined] / denominator[defined]))

    def evaluate(self, model, dataloader, device, attack_fn=None):
        """Run ``model`` over ``dataloader`` and compute all metrics.

        Args:
            model: module returning logits or a dict with ``'logits'`` (and
                optionally ``'features'``).
            dataloader: iterable of ``(data, labels)`` batches.
            device: device the batches are moved to.
            attack_fn: optional ``attack_fn(model, data, labels)`` returning
                adversarial inputs; enables ``attack_success_rate``.

        Returns:
            The metrics dict (also stored on ``self.metrics``).

        Raises:
            ValueError: if the dataloader yields no samples.
        """
        model.eval()
        # Start from a clean dict so optional metrics from an earlier call
        # (attack / feature / GPU entries) cannot leak into this result.
        self.metrics = {}

        all_preds = []
        all_labels = []
        all_probs = []
        all_features = []

        with torch.no_grad():
            for data, labels in dataloader:
                data, labels = data.to(device), labels.to(device)

                outputs = model(data)
                if isinstance(outputs, dict) and 'features' in outputs:
                    all_features.append(outputs['features'].cpu())
                logits = self._extract_logits(outputs)

                probs = F.softmax(logits, dim=-1)
                preds = logits.argmax(dim=-1)

                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
                all_probs.extend(probs.cpu().numpy())

        all_preds = np.array(all_preds)
        all_labels = np.array(all_labels)
        all_probs = np.array(all_probs)

        if len(all_labels) == 0:
            raise ValueError("dataloader yielded no samples to evaluate")

        # 1. Accuracy Metrics
        self.metrics['accuracy'] = accuracy_score(all_labels, all_preds)
        self.metrics['balanced_accuracy'] = balanced_accuracy_score(all_labels, all_preds)

        # 2. Precision, Recall, F1
        self.metrics['precision'] = precision_score(all_labels, all_preds, average='weighted', zero_division=0)
        self.metrics['recall'] = recall_score(all_labels, all_preds, average='weighted', zero_division=0)
        self.metrics['f1_score'] = f1_score(all_labels, all_preds, average='weighted', zero_division=0)

        # 3. ROC-AUC (if binary or can be computed)
        try:
            present = np.unique(all_labels)
            if len(present) == 2:
                # Score with the probability column of the larger label that is
                # actually present. For labels {0, 1} this is column 1.
                positive = present[1]
                self.metrics['roc_auc'] = roc_auc_score(all_labels == positive, all_probs[:, positive])
            else:
                self.metrics['roc_auc'] = roc_auc_score(all_labels, all_probs, multi_class='ovr')
        except (ValueError, IndexError):
            # Undefined for this label/score combination (e.g. a single class).
            self.metrics['roc_auc'] = 0.0

        # 4. Matthews Correlation Coefficient
        self.metrics['mcc'] = matthews_corrcoef(all_labels, all_preds)

        # 5. Cohen's Kappa
        self.metrics['cohen_kappa'] = cohen_kappa_score(all_labels, all_preds)

        # 6. Confusion Matrix Metrics (one-vs-rest counts per class)
        cm = confusion_matrix(all_labels, all_preds)
        tp = np.diag(cm)
        fp = cm.sum(axis=0) - tp
        fn = cm.sum(axis=1) - tp
        tn = cm.sum() - (fp + fn + tp)

        # True Positive Rate (Sensitivity/Recall) and True Negative Rate (Specificity)
        self.metrics['mean_tpr'] = self._mean_rate(tp, tp + fn)
        self.metrics['mean_tnr'] = self._mean_rate(tn, tn + fp)

        # 7. False Positive Rate
        self.metrics['mean_fpr'] = self._mean_rate(fp, fp + tn)

        # 8. False Negative Rate
        self.metrics['mean_fnr'] = self._mean_rate(fn, fn + tp)

        # 9. Uncertainty Metrics (using prediction entropy)
        pred_entropy = -np.sum(all_probs * np.log(all_probs + 1e-10), axis=1)
        self.metrics['mean_entropy'] = np.mean(pred_entropy)
        self.metrics['std_entropy'] = np.std(pred_entropy)

        # 10. Calibration Metrics
        self.metrics['brier_score'] = self._compute_brier_score(all_labels, all_probs)
        self.metrics['ece'] = self._compute_ece(all_labels, all_preds, all_probs)

        # 11. Detection Latency (simulated)
        self.metrics['avg_inference_time'] = self._measure_inference_time(model, dataloader, device)

        # 12. Memory Usage
        if torch.cuda.is_available():
            self.metrics['gpu_memory_mb'] = torch.cuda.memory_allocated() / 1e6

        # 13. Model Complexity
        self.metrics['num_parameters'] = sum(p.numel() for p in model.parameters())
        self.metrics['num_trainable_params'] = sum(p.numel() for p in model.parameters() if p.requires_grad)

        # 14. Attack Robustness (if attack function provided)
        if attack_fn:
            self.metrics['attack_success_rate'] = self._evaluate_attack_robustness(
                model, dataloader, device, attack_fn
            )

        # 15. Feature Quality Metrics (if features available)
        if all_features:
            features = torch.cat(all_features, dim=0).numpy()
            self.metrics['feature_variance'] = np.mean(np.var(features, axis=0))
            self.metrics['feature_sparsity'] = np.mean(features == 0)

        return self.metrics

    def _compute_brier_score(self, labels, probs):
        """Multi-class Brier score: mean over samples of ``sum((p - onehot)**2)``.

        Returns 0.0 for empty input.
        """
        labels = np.asarray(labels)
        probs = np.asarray(probs, dtype=float)
        if labels.size == 0:
            return 0.0

        onehot = np.zeros_like(probs)
        onehot[np.arange(len(labels)), labels] = 1
        return float(np.mean(np.sum((probs - onehot) ** 2, axis=1)))

    def _compute_ece(self, labels, preds, probs, n_bins=10):
        """Expected Calibration Error over ``n_bins`` equal-width confidence bins.

        Each bin ``(lower, upper]`` contributes
        ``|accuracy - mean confidence|`` weighted by its share of samples.
        """
        max_probs = np.max(probs, axis=1)
        correct = (preds == labels).astype(float)

        ece = 0
        for bin_i in range(n_bins):
            bin_lower = bin_i / n_bins
            bin_upper = (bin_i + 1) / n_bins

            in_bin = (max_probs > bin_lower) & (max_probs <= bin_upper)
            bin_size = np.sum(in_bin)
            if bin_size > 0:
                bin_acc = np.mean(correct[in_bin])
                bin_conf = np.mean(max_probs[in_bin])
                ece += (bin_size / len(labels)) * abs(bin_acc - bin_conf)

        return ece

    def _measure_inference_time(self, model, dataloader, device, n_samples=10):
        """Average forward-pass time in milliseconds over up to ``n_samples`` batches.

        Returns 0.0 when the dataloader yields nothing.
        """
        use_cuda = torch.cuda.is_available()
        times = []

        for i, (data, _) in enumerate(dataloader):
            if i >= n_samples:
                break

            data = data.to(device)

            if use_cuda:
                # Drain pending work (e.g. the host-to-device copy) so only the
                # forward pass is timed.
                torch.cuda.synchronize()
            start = time.perf_counter()
            with torch.no_grad():
                _ = model(data)
            if use_cuda:
                torch.cuda.synchronize()
            times.append(time.perf_counter() - start)

        if not times:
            return 0.0
        return float(np.mean(times) * 1000)  # Convert to ms

    def _evaluate_attack_robustness(self, model, dataloader, device, attack_fn, n_samples=100):
        """Fraction of adversarial inputs whose prediction differs from the label.

        Whole batches are consumed until at least ``n_samples`` samples have
        been processed. Returns 0 when no batch was processed.
        """
        success_count = 0
        total_count = 0

        for data, labels in dataloader:
            if total_count >= n_samples:
                break

            data, labels = data.to(device), labels.to(device)

            # Generate adversarial examples (outside no_grad: attacks may need gradients)
            adv_data = attack_fn(model, data, labels)

            # Evaluate on adversarial examples
            with torch.no_grad():
                logits = self._extract_logits(model(adv_data))
                adv_preds = logits.argmax(dim=-1)
                success_count += (adv_preds != labels).sum().item()
                total_count += len(labels)

        return success_count / max(total_count, 1)

    def print_metrics(self):
        """Print the stored metrics grouped by category; missing metrics are skipped."""
        print("\n" + "="*60)
        print("COMPREHENSIVE EVALUATION RESULTS (23 Metrics)")
        print("="*60)

        categories = {
            'Accuracy Metrics': ['accuracy', 'balanced_accuracy'],
            'Classification Metrics': ['precision', 'recall', 'f1_score', 'mcc', 'cohen_kappa'],
            'ROC/AUC Metrics': ['roc_auc'],
            'Error Rates': ['mean_fpr', 'mean_fnr', 'mean_tpr', 'mean_tnr'],
            'Calibration Metrics': ['brier_score', 'ece'],
            'Uncertainty Metrics': ['mean_entropy', 'std_entropy'],
            'Performance Metrics': ['avg_inference_time', 'gpu_memory_mb'],
            'Model Complexity': ['num_parameters', 'num_trainable_params'],
            'Feature Metrics': ['feature_variance', 'feature_sparsity'],
            'Robustness': ['attack_success_rate']
        }

        for category, metric_names in categories.items():
            print(f"\n{category}:")
            print("-" * 40)
            for metric in metric_names:
                if metric not in self.metrics:
                    continue
                value = self.metrics[metric]
                if isinstance(value, float):
                    print(f"  {metric:25s}: {value:.4f}")
                else:
                    print(f"  {metric:25s}: {value}")




## Complete MambaShield Model


# In [None]:
class MambaShield(nn.Module):
    """Sequence classifier: input projection, residual SSM layers, MLP head.

    The classification head reads the hidden state of the last timestep.
    """

    def __init__(self, input_dim, num_classes, config=None):
        """
        Args:
            input_dim: number of features per timestep.
            num_classes: number of output logits.
            config: optional dict overriding any of ``hidden_dim``, ``d_state``,
                ``n_layers``, ``dropout`` and ``use_rl``. Missing keys fall
                back to the defaults below, and the caller's dict is not
                modified.
        """
        super().__init__()

        defaults = {
            'hidden_dim': 128,
            'd_state': 16,
            'n_layers': 2,
            'dropout': 0.1,
            'use_rl': False  # Disable RL by default for stability
        }
        # Merge rather than replace so a partial config cannot raise KeyError.
        config = {**defaults, **(config or {})}

        self.config = config
        self.input_dim = input_dim
        self.num_classes = num_classes

        hidden_dim = config['hidden_dim']
        n_layers = config['n_layers']

        # Input projection
        self.input_proj = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU()
        )

        # Mamba layers
        self.mamba_layers = nn.ModuleList([
            ImprovedSelectiveSSM(hidden_dim, d_state=config['d_state'])
            for _ in range(n_layers)
        ])

        # One pre-norm per Mamba layer
        self.layer_norms = nn.ModuleList([
            nn.LayerNorm(hidden_dim)
            for _ in range(n_layers)
        ])

        # Dropout applied to each Mamba layer's output
        self.dropout = nn.Dropout(config['dropout'])

        # Classification head
        self.classifier = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(config['dropout']),
            nn.Linear(hidden_dim // 2, num_classes)
        )

    def forward(self, x, return_features=False):
        """Classify a batch of sequences.

        Args:
            x: tensor of shape (batch, seq_len, input_dim).
            return_features: also return the last-timestep hidden state.

        Returns:
            dict with ``'logits'`` (batch, num_classes) and, when requested,
            ``'features'`` (batch, hidden_dim).

        Raises:
            ValueError: if ``x`` is not 3-dimensional.
        """
        if x.dim() != 3:
            raise ValueError(
                f"expected input of shape (batch, seq_len, input_dim), got {tuple(x.shape)}"
            )

        # Input projection
        x = self.input_proj(x)

        # Pre-norm Mamba layers with residual connections
        for mamba, norm in zip(self.mamba_layers, self.layer_norms):
            residual = x
            x = self.dropout(mamba(norm(x)))
            x = x + residual

        # Use last timestep for classification
        features = x[:, -1, :]
        logits = self.classifier(features)

        outputs = {'logits': logits}
        if return_features:
            outputs['features'] = features

        return outputs




## Training Pipeline with Poisoning Attack Simulation


# In [None]:
class PoisoningAttackSimulator:
    """Simulate input-perturbation, label-flipping and backdoor poisoning."""

    def __init__(self, epsilon=0.1):
        # Step size of the signed-gradient perturbation.
        self.epsilon = epsilon

    def gradient_based_poisoning(self, model, x, y):
        """Return ``x + epsilon * sign(d loss / d x)`` for the cross-entropy loss.

        The gradient is taken with ``torch.autograd.grad`` with respect to the
        input only, so the model's parameter ``.grad`` fields are untouched,
        and the call also works inside a ``torch.no_grad()`` context.

        Args:
            model: callable returning logits or a dict with ``'logits'``.
            x: input batch.
            y: integer class labels for ``x``.

        Returns:
            Detached perturbed copy of ``x``.
        """
        x_adv = x.clone().detach().requires_grad_(True)

        with torch.enable_grad():
            outputs = model(x_adv)
            logits = outputs['logits'] if isinstance(outputs, dict) else outputs
            loss = F.cross_entropy(logits, y)
            (input_grad,) = torch.autograd.grad(loss, x_adv)

        # Generate adversarial perturbation
        perturbation = self.epsilon * input_grad.sign()
        return (x + perturbation).detach()

    def label_flipping(self, y, flip_rate=0.2, num_classes=None):
        """Flip a random fraction of labels to a different class.

        Args:
            y: 1-D integer label tensor with values in ``[0, num_classes)``.
            flip_rate: fraction of labels to flip (clamped to ``[0, 1]``).
            num_classes: total number of classes. When ``None`` it is inferred
                as ``y.max() + 1``, so only labels seen in ``y`` can be drawn.

        Returns:
            A new tensor; ``y`` is not modified. If fewer than two classes are
            available there is nothing to flip to and a plain copy is returned.
        """
        y_flipped = y.clone()
        n_flip = max(0, min(int(len(y) * flip_rate), len(y)))
        flip_indices = torch.randperm(len(y))[:n_flip]
        if n_flip == 0:
            return y_flipped

        if num_classes is None:
            num_classes = int(y.max().item()) + 1
        if num_classes < 2:
            return y_flipped

        # Flip to random different class
        for idx in flip_indices:
            current_label = y_flipped[idx].item()
            # Draw among the other num_classes - 1 labels by skipping the current one.
            new_label = random.randrange(num_classes - 1)
            if new_label >= current_label:
                new_label += 1
            y_flipped[idx] = new_label

        return y_flipped

    def backdoor_attack(self, x, trigger_pattern=None, target_class=0, poison_rate=0.1):
        """Add a trigger pattern to a random subset of samples.

        Labels are not touched here and ``target_class`` is not used by this
        method; callers presumably relabel the returned indices themselves.

        Args:
            x: batch of samples. The default trigger indexes each sample as
                ``[:, :5]``, so samples need at least two dimensions.
            trigger_pattern: tensor broadcastable to one sample. Defaults to
                ones on the first five entries of the last axis.
            target_class: accepted for interface compatibility.
            poison_rate: fraction of samples to poison (clamped to ``[0, 1]``).

        Returns:
            ``(x_backdoor, poison_indices)``: a poisoned copy of ``x`` and the
            indices that were modified.
        """
        x_backdoor = x.clone()
        n_poison = max(0, min(int(len(x) * poison_rate), len(x)))
        poison_indices = torch.randperm(len(x))[:n_poison]
        if n_poison == 0:
            return x_backdoor, poison_indices

        if trigger_pattern is None:
            # Simple trigger: modify first few features
            trigger_pattern = torch.zeros_like(x[0])
            trigger_pattern[:, :5] = 1.0

        # randperm indices are unique, so one indexed update equals the per-sample loop.
        selected = poison_indices.to(x.device)
        x_backdoor[selected] = x_backdoor[selected] + trigger_pattern

        return x_backdoor, poison_indices




# Main Training Function



# In [None]:
def train_mambashield(config):
    """Main training function with all components"""
    
    print("="*60)
    print("MAMBASHIELD TRAINING PIPELINE")
    print("="*60)
    
    # Memory manager
    mem_manager = P100MemoryManager()
    
    # Load data (using smaller sample for demo)
    print("\n1. Loading Dataset...")
    dataset_path = '/kaggle/input/poisoning-i/CIC_IoT_M3.csv'
    
    # Load with sampling for P100
    df = pd.read_csv(dataset_path, nrows=50000)  # Limit rows for demo
    
    # Process data
    label_col = 'Label'
    feature_cols = [c for c in df.columns if c != label_col]
    
    # Handle non-numeric
    for col in feature_cols:
        df[col] = pd.to_numeric(df[col], errors='coerce')
    df = df.fillna(0)
    
    X = df[feature_cols].values
    y = df[label_col].values
    
    # Encode labels
    le = LabelEncoder()
    y = le.fit_transform(y.astype(str))
    
    # Scale features
    scaler = StandardScaler()
    X = scaler.fit_transform(X)
    
    print(f"Data shape: {X.shape}")
    print(f"Classes: {len(np.unique(y))}")
    
    # Create sequences
    seq_len = config['seq_len']
    sequences = []
    labels = []
    
    for i in range(len(X) - seq_len + 1):
        sequences.append(X[i:i+seq_len])
        labels.append(y[i+seq_len-1])
    
    X_seq = np.array(sequences)
    y_seq = np.array(labels)
    
    # Split data
    X_train, X_test, y_train, y_test = train_test_split(
        X_seq, y_seq, test_size=0.2, random_state=42, stratify=y_seq
    )
    
    X_train, X_val, y_train, y_val = train_test_split(
        X_train, y_train, test_size=0.2, random_state=42, stratify=y_train
    )
    
    print(f"\nTrain: {X_train.shape}, Val: {X_val.shape}, Test: {X_test.shape}")
    
    # Create datasets
    train_dataset = TensorDataset(
        torch.FloatTensor(X_train),
        torch.LongTensor(y_train)
    )
    val_dataset = TensorDataset(
        torch.FloatTensor(X_val),
        torch.LongTensor(y_val)
    )
    test_dataset = TensorDataset(
        torch.FloatTensor(X_test),
        torch.LongTensor(y_test)
    )
    
    # Dataloaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=0
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=config['batch_size'],
        shuffle=False,
        num_workers=0
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=config['batch_size'],
        shuffle=False,
        num_workers=0
    )
    
    # Create model
    print("\n2. Creating MambaShield Model...")
    model = MambaShield(
        input_dim=X.shape[1],
        num_classes=len(np.unique(y)),
        config={
            'hidden_dim': config['hidden_dim'],
            'd_state': config['d_state'],
            'n_layers': config['n_layers'],
            'dropout': config['dropout']
        }
    ).to(device)
    
    print(f"Model parameters: {sum(p.numel() for p in model.parameters())/1e6:.2f}M")
    
    # Optimizer
    optimizer = optim.AdamW(model.parameters(), lr=config['lr'], weight_decay=1e-5)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config['epochs'])
    criterion = nn.CrossEntropyLoss()
    scaler = GradScaler()
    
    # Initialize components
    attack_sim = PoisoningAttackSimulator(epsilon=0.1)
    evaluator = ComprehensiveEvaluator()
    pac_bayes = PACBayesRobustness(model)
    
    # Training history
    history = {
        'train_loss': [], 'train_acc': [],
        'val_loss': [], 'val_acc': [],
        'metrics': []
    }
    
    # Training loop
    print("\n3. Starting Training...")
    best_val_acc = 0
    
    for epoch in range(config['epochs']):
        print(f"\n" + "="*50)
        print(f"Epoch {epoch+1}/{config['epochs']}")
        print("="*50)
        
        # Training
        model.train()
        train_loss = 0
        train_correct = 0
        train_total = 0
        
        pbar = tqdm(train_loader, desc='Training')
        for batch_idx, (data, target) in enumerate(pbar):
            data, target = data.to(device), target.to(device)
            
            # Simulate poisoning attacks (25% of batches)
            if random.random() < 0.25:
                if random.random() < 0.5:
                    data = attack_sim.gradient_based_poisoning(model, data, target)
                else:
                    target = attack_sim.label_flipping(target, flip_rate=0.1)
            
            optimizer.zero_grad()
            
            # Mixed precision training
            with autocast():
                outputs = model(data)
                loss = criterion(outputs['logits'], target)
            
            scaler.scale(loss).backward()
            
            # Gradient clipping
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            
            scaler.step(optimizer)
            scaler.update()
            
            # Metrics
            train_loss += loss.item()
            pred = outputs['logits'].argmax(dim=1)
            train_correct += (pred == target).sum().item()
            train_total += target.size(0)
            
            # Update progress bar
            pbar.set_postfix({
                'loss': f"{loss.item():.4f}",
                'acc': f"{100.*train_correct/train_total:.2f}%"
            })
            
            # Memory cleanup
            if batch_idx % 20 == 0:
                mem_manager.clear_memory()
        
        # Calculate epoch metrics
        train_loss /= len(train_loader)
        train_acc = 100. * train_correct / train_total
        
        # Validation
        model.eval()
        val_loss = 0
        val_correct = 0
        val_total = 0
        
        with torch.no_grad():
            for data, target in tqdm(val_loader, desc='Validation'):
                data, target = data.to(device), target.to(device)
                
                outputs = model(data)
                loss = criterion(outputs['logits'], target)
                
                val_loss += loss.item()
                pred = outputs['logits'].argmax(dim=1)
                val_correct += (pred == target).sum().item()
                val_total += target.size(0)
        
        val_loss /= len(val_loader)
        val_acc = 100. * val_correct / val_total
        
        # Update scheduler
        scheduler.step()
        
        # Save history
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        
        print(f"\nEpoch {epoch+1} Summary:")
        print(f"Train - Loss: {train_loss:.4f}, Acc: {train_acc:.2f}%")
        print(f"Val   - Loss: {val_loss:.4f}, Acc: {val_acc:.2f}%")
        
        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_model_state = model.state_dict()
            print(f"New best model! Val Acc: {best_val_acc:.2f}%")
        
        # Memory cleanup
        mem_manager.clear_memory()
    
    # Load best model
    model.load_state_dict(best_model_state)
    
    # Final evaluation
    print("\n4. Final Evaluation on Test Set...")
    metrics = evaluator.evaluate(model, test_loader, device)
    evaluator.print_metrics()
    
    # PAC-Bayes bound
    print("\n5. Computing PAC-Bayes Robustness Bound...")
    with torch.no_grad():
        test_loss = 0
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            outputs = model(data)
            loss = criterion(outputs['logits'], target)
            test_loss += loss.item()
        test_loss /= len(test_loader)
    
    pac_bounds = pac_bayes.pac_bayes_bound(
        torch.tensor(test_loss),
        len(test_dataset)
    )
    
    print("\nPAC-Bayes Bounds:")
    for key, value in pac_bounds.items():
        print(f"  {key}: {value:.4f}")
    
    # Plot training history
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))
    
    axes[0].plot(history['train_loss'], label='Train')
    axes[0].plot(history['val_loss'], label='Val')
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('Loss')
    axes[0].set_title('Training Loss')
    axes[0].legend()
    axes[0].grid(True)
    
    axes[1].plot(history['train_acc'], label='Train')
    axes[1].plot(history['val_acc'], label='Val')
    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('Accuracy (%)')
    axes[1].set_title('Training Accuracy')
    axes[1].legend()
    axes[1].grid(True)
    
    plt.tight_layout()
    plt.show()
    
    return model, history, metrics




# Execute Training


In [None]:
if __name__ == "__main__":
    # Hyperparameters for this training run
    config = dict(
        seq_len=10,      # Sequence length
        batch_size=32,   # Batch size for P100
        hidden_dim=128,  # Hidden dimension
        d_state=16,      # SSM state dimension
        n_layers=2,      # Number of Mamba layers
        dropout=0.1,     # Dropout rate
        lr=1e-4,         # Learning rate
        epochs=10,       # Number of epochs
    )

    # Echo the settings so the run is self-describing in the notebook output
    print("Configuration:")
    for key, value in config.items():
        print(f"  {key}: {value}")

    # Run the full training pipeline
    model, history, metrics = train_mambashield(config)

    banner = "=" * 60
    print("\n" + banner)
    print("TRAINING COMPLETE!")
    print(banner)
    best_val = max(history['val_acc'])
    print(f"Best validation accuracy: {best_val:.2f}%")
    test_acc_pct = metrics['accuracy'] * 100
    print(f"Final test accuracy: {test_acc_pct:.2f}%")

    # Persist the weights together with the config and results of this run
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'config': config,
        'metrics': metrics,
        'history': history,
    }
    torch.save(checkpoint, 'mambashield_model.pth')

    print("\nModel saved to 'mambashield_model.pth'")



# NEW ADDITIONAL COMPREHENSIVE CODE FOR MAMBASHIELD

In [None]:
#!/usr/bin/env python
# coding: utf-8

"""
MambaShield: Complete Multi-Dataset Implementation
Handles CIC-IoT-2023, CSE-CICIDS2018, and UNSW-NB15 with unified taxonomy
"""

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import DataLoader, TensorDataset, ConcatDataset
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import gc
import warnings
from typing import Dict, List, Tuple, Optional, Union
from collections import defaultdict
import math
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    confusion_matrix, roc_auc_score, classification_report
)

warnings.filterwarnings('ignore')

# ============================================================================
# UNIFIED ATTACK TAXONOMY
# ============================================================================

class UnifiedTaxonomy:
    """Maps dataset-specific attack labels onto a shared set of categories.

    ``taxonomy`` lists, per unified category, the raw label spellings that
    belong to it.  ``reverse_map`` is the lower-cased inverse used for lookups.
    """

    def __init__(self):
        normal = ['Normal/Benign', 'BENIGN', 'Benign', 'Normal', '0']
        dos_ddos = [
            'DoS', 'DDoS', 'DDOS-SLOWLORIS', 'DDOS-SYNONYMOUSIP_FLOOD',
            'DDOS-ICMP_FLOOD', 'DDOS-RSTFINFLOOD', 'DDOS-PSHACK_FLOOD',
            'DDOS-SYN_FLOOD', 'DDOS-TCP_FLOOD', 'DDOS-UDP_FLOOD',
            'DOS-UDP_FLOOD', 'DOS-SYN_FLOOD', 'DOS-TCP_FLOOD',
            'DoS_Hulk', 'DoS_GoldenEye', 'DoS_Slowloris', 'DoS_Slowhttptest',
        ]
        reconnaissance = [
            'Scanning', 'RECON-PORTSCAN', 'RECON-OSSCAN', 'RECON-HOSTDISCOVERY',
            'RECON-PINGSWEEP', 'VULNERABILITYSCAN', 'Heartbleed', 'PortScan',
            'Reconnaissance', 'Analysis',
        ]
        malware = [
            'BACKDOOR_MALWARE', 'Rootkit', 'Trojan', 'Worm', 'Botnet',
            'Malware', 'Bot', 'Mirai', 'Generic', 'Shellcode',
        ]
        injection = [
            'SQL_Injection', 'SQLINJECTION', 'COMMANDINJECTION', 'XSS',
            'Web Attack', 'Web-based',
        ]
        brute_force = [
            'DICTIONARYBRUTEFORCE', 'Brute_Force', 'FTP_Patator',
            'Password_Attack', 'Brute Force', 'SSH-Patator',
        ]
        exploitation = ['Infiltration', 'Backdoor', 'Exploits', 'Fuzzers']

        # Insertion order of the categories is part of the public attribute.
        self.taxonomy = {
            'Normal': normal,
            'DoS/DDoS': dos_ddos,
            'Reconnaissance': reconnaissance,
            'Malware': malware,
            'Injection': injection,
            'BruteForce': brute_force,
            'Exploitation': exploitation,
        }

        # Lower-cased raw label -> unified category
        self.reverse_map = {
            raw_name.lower(): category
            for category, raw_names in self.taxonomy.items()
            for raw_name in raw_names
        }

    def map_label(self, label):
        """Return the unified category for ``label``, or ``'Other'`` if unknown.

        The label is stringified, lower-cased and stripped before the lookup.
        """
        normalized = str(label).lower().strip()
        if normalized in self.reverse_map:
            return self.reverse_map[normalized]
        return 'Other'


# ============================================================================
# UNIFIED FEATURE PROCESSOR
# ============================================================================

class UnifiedFeatureProcessor:
    """Process and align features across datasets.

    Each dataset's raw column names are renamed onto a shared list of flow
    features (``common_features``); features a dataset does not provide are
    filled with zeros so every dataset ends up with the same columns.

    Args:
        target_features: Maximum number of common features to keep.  The
            common list is sliced to this length, so a value larger than the
            list simply keeps every common feature.
    """

    # Raw column name -> common feature name, keyed by lower-cased dataset name.
    # Kept at class level so the table is built once rather than on every call.
    _FEATURE_MAPPINGS = {
        'cic': {
            'Flow Duration': 'flow_duration',
            'Total Fwd Packets': 'total_fwd_packets',
            'Total Backward Packets': 'total_bwd_packets',
            'Total Length of Fwd Packets': 'total_length_fwd_packets',
            'Total Length of Bwd Packets': 'total_length_bwd_packets',
            'Fwd Packet Length Max': 'fwd_packet_length_max',
            'Fwd Packet Length Min': 'fwd_packet_length_min',
            'Fwd Packet Length Mean': 'fwd_packet_length_mean',
            'Fwd Packet Length Std': 'fwd_packet_length_std',
            'Bwd Packet Length Max': 'bwd_packet_length_max',
            'Bwd Packet Length Min': 'bwd_packet_length_min',
            'Bwd Packet Length Mean': 'bwd_packet_length_mean',
            'Bwd Packet Length Std': 'bwd_packet_length_std',
            'Flow Bytes/s': 'flow_bytes_per_sec',
            'Flow Packets/s': 'flow_packets_per_sec',
            'Flow IAT Mean': 'flow_iat_mean',
            'Flow IAT Std': 'flow_iat_std',
            'Flow IAT Max': 'flow_iat_max',
            'Flow IAT Min': 'flow_iat_min',
            'Fwd IAT Total': 'fwd_iat_total',
            'Fwd IAT Mean': 'fwd_iat_mean',
            'Fwd IAT Std': 'fwd_iat_std',
            'Fwd IAT Max': 'fwd_iat_max',
            'Fwd IAT Min': 'fwd_iat_min',
            'Bwd IAT Total': 'bwd_iat_total',
            'Bwd IAT Mean': 'bwd_iat_mean',
            'Bwd IAT Std': 'bwd_iat_std',
            'Bwd IAT Max': 'bwd_iat_max',
            'Bwd IAT Min': 'bwd_iat_min',
            'Fwd PSH Flags': 'fwd_psh_flags',
            'Bwd PSH Flags': 'bwd_psh_flags',
            'Fwd URG Flags': 'fwd_urg_flags',
            'Bwd URG Flags': 'bwd_urg_flags',
            'FIN Flag Count': 'fin_flag_count',
            'SYN Flag Count': 'syn_flag_count',
            'RST Flag Count': 'rst_flag_count',
            'PSH Flag Count': 'psh_flag_count',
            'ACK Flag Count': 'ack_flag_count',
            'URG Flag Count': 'urg_flag_count',
            'CWE Flag Count': 'cwe_flag_count',
            'ECE Flag Count': 'ece_flag_count',
            'Protocol': 'protocol',
            'Source Port': 'src_port',
            'Destination Port': 'dst_port'
        },
        'cse': {
            # Similar mappings for CSE-CICIDS2018
            'Flow Duration': 'flow_duration',
            'Tot Fwd Pkts': 'total_fwd_packets',
            'Tot Bwd Pkts': 'total_bwd_packets',
            'TotLen Fwd Pkts': 'total_length_fwd_packets',
            'TotLen Bwd Pkts': 'total_length_bwd_packets',
            'Fwd Pkt Len Max': 'fwd_packet_length_max',
            'Fwd Pkt Len Min': 'fwd_packet_length_min',
            'Fwd Pkt Len Mean': 'fwd_packet_length_mean',
            'Fwd Pkt Len Std': 'fwd_packet_length_std',
            'Bwd Pkt Len Max': 'bwd_packet_length_max',
            'Bwd Pkt Len Min': 'bwd_packet_length_min',
            'Bwd Pkt Len Mean': 'bwd_packet_length_mean',
            'Bwd Pkt Len Std': 'bwd_packet_length_std',
            'Flow Byts/s': 'flow_bytes_per_sec',
            'Flow Pkts/s': 'flow_packets_per_sec',
            'Protocol': 'protocol'
        },
        'ton': {
            # Mappings for UNSW-TON-NB15
            'dur': 'flow_duration',
            'spkts': 'total_fwd_packets',
            'dpkts': 'total_bwd_packets',
            'sbytes': 'total_length_fwd_packets',
            'dbytes': 'total_length_bwd_packets',
            'rate': 'flow_packets_per_sec',
            'proto': 'protocol',
            'sport': 'src_port',
            'dport': 'dst_port'
        }
    }

    def __init__(self, target_features=84):
        self.target_features = target_features
        self.common_features = self._define_common_features()
        self.scaler = StandardScaler()

    def _define_common_features(self):
        """Return the ordered list of common network-flow feature names."""
        return [
            # Packet statistics
            'flow_duration', 'total_fwd_packets', 'total_bwd_packets',
            'total_length_fwd_packets', 'total_length_bwd_packets',

            # Packet length statistics
            'fwd_packet_length_max', 'fwd_packet_length_min', 'fwd_packet_length_mean',
            'fwd_packet_length_std', 'bwd_packet_length_max', 'bwd_packet_length_min',
            'bwd_packet_length_mean', 'bwd_packet_length_std',

            # Flow statistics
            'flow_bytes_per_sec', 'flow_packets_per_sec', 'flow_iat_mean',
            'flow_iat_std', 'flow_iat_max', 'flow_iat_min',

            # Inter-arrival times
            'fwd_iat_total', 'fwd_iat_mean', 'fwd_iat_std', 'fwd_iat_max', 'fwd_iat_min',
            'bwd_iat_total', 'bwd_iat_mean', 'bwd_iat_std', 'bwd_iat_max', 'bwd_iat_min',

            # TCP flags
            'fwd_psh_flags', 'bwd_psh_flags', 'fwd_urg_flags', 'bwd_urg_flags',
            'fin_flag_count', 'syn_flag_count', 'rst_flag_count', 'psh_flag_count',
            'ack_flag_count', 'urg_flag_count', 'cwe_flag_count', 'ece_flag_count',

            # Header information
            'fwd_header_length', 'bwd_header_length', 'fwd_packets_per_sec',
            'bwd_packets_per_sec', 'min_packet_length', 'max_packet_length',
            'packet_length_mean', 'packet_length_std', 'packet_length_variance',

            # Additional flow features
            'down_up_ratio', 'average_packet_size', 'fwd_segment_size_avg',
            'bwd_segment_size_avg', 'fwd_bulk_rate_avg', 'bwd_bulk_rate_avg',
            'subflow_fwd_packets', 'subflow_fwd_bytes', 'subflow_bwd_packets',
            'subflow_bwd_bytes', 'init_win_bytes_forward', 'init_win_bytes_backward',
            'act_data_pkt_fwd', 'min_seg_size_forward', 'active_mean', 'active_std',
            'active_max', 'active_min', 'idle_mean', 'idle_std', 'idle_max', 'idle_min',

            # Protocol features
            'protocol', 'src_port', 'dst_port'
        ]

    def align_features(self, df, dataset_name):
        """Align dataset features to the common feature set.

        Args:
            df: Raw feature DataFrame of one dataset.
            dataset_name: Dataset key (case-insensitive); names without a
                mapping table are aligned on their existing column names only.

        Returns:
            DataFrame with one row per input row (same index as ``df``) and
            one column per selected common feature.  Features the dataset
            lacks are zero-filled.
        """
        print(f"Aligning features for {dataset_name}...")

        # Get appropriate mapping and rename columns onto the common names
        mapping = self._FEATURE_MAPPINGS.get(dataset_name.lower(), {})
        df_renamed = df.rename(columns=mapping)

        # Seed the frame with the source index.  Starting from an index-less
        # DataFrame makes scalar assignments create zero-length columns, which
        # later turn into NaN (or an empty result when no feature matches).
        aligned_df = pd.DataFrame(index=df_renamed.index)

        for feature in self.common_features[:self.target_features]:
            if feature in df_renamed.columns:
                aligned_df[feature] = df_renamed[feature]
            else:
                # Dataset does not provide this feature
                aligned_df[feature] = 0

        return aligned_df

    def process_dataset(self, df, dataset_name, label_col):
        """Split labels from features and return cleaned, aligned features.

        Args:
            df: Full dataset including the label column.
            dataset_name: Dataset key passed on to ``align_features``.
            label_col: Name of the label column.  If it is absent, the last
                column of ``df`` is used as the label instead.

        Returns:
            Tuple ``(aligned_features, labels)``: a numeric DataFrame with
            NaNs replaced by 0 and values clipped to [-1e10, 1e10], and the
            label Series.
        """
        if label_col in df.columns:
            labels = df[label_col]
            feature_df = df.drop(columns=[label_col])
        else:
            # Fallback: treat the last column as the label and remove that
            # same column from the features so the label cannot leak in.
            labels = df.iloc[:, -1]
            feature_df = df.iloc[:, :-1]

        # Align features
        aligned_features = self.align_features(feature_df, dataset_name)

        # Non-numeric entries become NaN ...
        for col in aligned_features.columns:
            aligned_features[col] = pd.to_numeric(aligned_features[col], errors='coerce')

        # ... and NaNs become 0
        aligned_features = aligned_features.fillna(0)

        # Clip extreme values (this also bounds +/-inf)
        aligned_features = aligned_features.clip(lower=-1e10, upper=1e10)

        return aligned_features, labels


# ============================================================================
# MULTI-DATASET LOADER
# ============================================================================

class MultiDatasetLoader:
    """Load and process all three datasets with unified taxonomy.

    Args:
        config: Run configuration dict; stored on the instance.
    """

    def __init__(self, config):
        self.config = config
        self.taxonomy = UnifiedTaxonomy()
        self.feature_processor = UnifiedFeatureProcessor(target_features=84)
        self.label_encoder = LabelEncoder()

    def load_dataset(self, file_path, dataset_name, sample_frac=1.0):
        """Load a single dataset from CSV.

        Args:
            file_path: Path to the CSV file.
            dataset_name: Dataset key (``'cic'``, ``'cse'`` or ``'ton'``); it
                selects the label column and the feature mapping.
            sample_frac: Fraction of a 100000-row budget to read.  Rows are
                taken from the start of the file, not sampled at random.

        Returns:
            Tuple ``(features, encoded_labels, unified_labels)``.  The encoded
            labels come from an encoder fitted on this dataset alone, so the
            integers are only comparable within this dataset.

        Raises:
            ValueError: If no rows could be read (empty file or a row budget
                of zero).
        """
        print(f"\nLoading {dataset_name} dataset...")

        # Determine label column
        label_columns = {
            'cic': 'Label',
            'cse': ' Label',
            'ton': 'label'
        }
        label_col = label_columns.get(dataset_name.lower(), 'Label')

        # Read dataset in chunks for memory efficiency
        chunk_size = 10000
        max_rows = int(100000 * sample_frac)  # Limit total rows

        chunks = []
        remaining = max_rows
        for chunk in pd.read_csv(file_path, chunksize=chunk_size):
            if remaining <= 0:
                break
            # Trim the final chunk so the limit holds exactly instead of
            # overshooting by up to a whole chunk.
            chunks.append(chunk.iloc[:remaining])
            remaining -= len(chunk)

        if not chunks:
            raise ValueError(
                f"No rows read from {file_path!r} for {dataset_name} "
                f"(row limit {max_rows})"
            )

        df = pd.concat(chunks, ignore_index=True)
        print(f"Loaded {len(df)} samples from {dataset_name}")

        # Process features
        features, labels = self.feature_processor.process_dataset(df, dataset_name, label_col)

        # Map labels to unified taxonomy
        unified_labels = [self.taxonomy.map_label(label) for label in labels]

        # Encode labels (per-dataset encoding; see Returns)
        encoded_labels = self.label_encoder.fit_transform(unified_labels)

        return features.values, encoded_labels, unified_labels

    def load_all_datasets(self, dataset_paths, sample_frac=0.1):
        """Load every available dataset and stack them into one array set.

        Args:
            dataset_paths: Mapping of dataset key to CSV path.  Missing files
                are skipped with a warning.
            sample_frac: Passed through to ``load_dataset``.

        Returns:
            Tuple ``(X, y, dataset_ids)`` or ``(None, None, None)`` when no
            dataset could be loaded.  ``X`` is standard-scaled, ``y`` holds
            class indices into ``self.label_encoder.classes_`` and
            ``dataset_ids`` holds each row's position of its dataset within
            ``dataset_paths`` (positions of skipped datasets stay unused).
        """
        all_features = []
        all_unified_labels = []
        all_dataset_ids = []

        for i, (name, path) in enumerate(dataset_paths.items()):
            if not os.path.exists(path):
                print(f"Warning: {path} not found, skipping {name}")
                continue

            features, _, unified_labels = self.load_dataset(path, name, sample_frac)

            all_features.append(features)
            all_unified_labels.extend(unified_labels)
            all_dataset_ids.extend([i] * len(unified_labels))

            print(f"{name}: {features.shape[0]} samples, {len(set(unified_labels))} classes")

        if not all_features:
            return None, None, None

        # Combine all datasets
        X = np.vstack(all_features)
        dataset_ids = np.array(all_dataset_ids)

        # Encode once over the combined labels.  Fitting the encoder per
        # dataset would give the same integer different meanings in different
        # datasets, because each fit only sees that dataset's categories.
        y = self.label_encoder.fit_transform(all_unified_labels)

        # Scale features
        # NOTE(review): the scaler is fitted on all rows, before any
        # train/test split done by the caller - confirm this is acceptable.
        X = self.feature_processor.scaler.fit_transform(X)

        print(f"\nCombined dataset: {X.shape[0]} samples, {X.shape[1]} features")
        print(f"Class distribution: {dict(zip(*np.unique(y, return_counts=True)))}")

        return X, y, dataset_ids


# ============================================================================
# ENHANCED MAMBASHIELD FOR MULTI-DATASET
# ============================================================================

class MultiDatasetMambaShield(nn.Module):
    """MambaShield with multi-dataset support.

    Args:
        input_dim: Number of features per timestep.
        num_classes: Number of output classes of the unified classifier.
        num_datasets: Number of distinct dataset ids (sizes the dataset
            embedding and the per-dataset heads).
        config: Optional dict with ``hidden_dim``, ``d_state``, ``n_layers``,
            ``dropout`` and ``use_dataset_embedding``.  Missing keys fall back
            to the defaults below; extra keys are kept untouched.
    """

    def __init__(self, input_dim, num_classes, num_datasets=3, config=None):
        super().__init__()

        default_config = {
            'hidden_dim': 128,
            'd_state': 16,
            'n_layers': 3,
            'dropout': 0.1,
            'use_dataset_embedding': True
        }
        # Merge so that a partial config does not raise KeyError
        config = {**default_config, **(config or {})}

        self.config = config
        self.num_datasets = num_datasets

        # Dataset embeddings for multi-modal learning
        if config['use_dataset_embedding']:
            self.dataset_embedding = nn.Embedding(num_datasets, 32)
            input_dim += 32

        # Input projection
        self.input_proj = nn.Sequential(
            nn.Linear(input_dim, config['hidden_dim']),
            nn.LayerNorm(config['hidden_dim']),
            nn.ReLU(),
            nn.Dropout(config['dropout'])
        )

        # Mamba blocks (simplified for stability)
        self.mamba_blocks = nn.ModuleList([
            self._create_mamba_block(config['hidden_dim'], config['d_state'])
            for _ in range(config['n_layers'])
        ])

        # Multi-head attention for temporal fusion
        # (hidden_dim must be divisible by the 8 heads)
        self.temporal_attention = nn.MultiheadAttention(
            config['hidden_dim'],
            num_heads=8,
            dropout=config['dropout'],
            batch_first=True
        )

        # Per-dataset projection heads
        self.dataset_heads = nn.ModuleList([
            nn.Linear(config['hidden_dim'], config['hidden_dim'] // 2)
            for _ in range(num_datasets)
        ])

        self.unified_classifier = nn.Sequential(
            nn.LayerNorm(config['hidden_dim']),
            nn.Linear(config['hidden_dim'], config['hidden_dim'] // 2),
            nn.ReLU(),
            nn.Dropout(config['dropout']),
            nn.Linear(config['hidden_dim'] // 2, num_classes)
        )

    def _create_mamba_block(self, hidden_dim, d_state):
        """Create simplified Mamba block.

        ``d_state`` is accepted for signature compatibility but is not used
        by this simplified block.
        """
        return nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim * 2),
            nn.GLU(dim=-1),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=3, padding=1, groups=hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim)
        )

    def _apply_mamba_block(self, mamba_block, x):
        """Run one block on a (B, L, D) tensor and return a (B, L, D) tensor.

        Every layer except the convolution works on the last (feature) axis,
        so the tensor stays in (B, L, D) layout and is only flipped to
        (B, D, L) around ``nn.Conv1d``, which convolves over the last axis.
        """
        for layer in mamba_block:
            if isinstance(layer, nn.Conv1d):
                x = layer(x.transpose(1, 2)).transpose(1, 2)
            else:
                x = layer(x)
        return x

    def forward(self, x, dataset_ids=None):
        """Forward pass with multi-dataset support.

        Args:
            x: Input of shape (batch, seq_len, input_dim).
            dataset_ids: Optional (batch,) integer tensor of dataset ids.

        Returns:
            Dict with ``'logits'`` (batch, num_classes) and ``'features'``
            (batch, hidden_dim).  When ``dataset_ids`` is given and at least
            one id matches a head, ``'dataset_specific'`` holds the head
            outputs for the datasets present in the batch, in ascending
            dataset-id order (absent datasets are skipped).
        """
        B, L, _ = x.shape

        if self.config['use_dataset_embedding']:
            if dataset_ids is not None:
                dataset_emb = self.dataset_embedding(dataset_ids)
                dataset_emb = dataset_emb.unsqueeze(1).expand(-1, L, -1)
            else:
                # input_proj was sized for features + embedding, so keep the
                # width consistent with a neutral all-zero embedding.
                dataset_emb = x.new_zeros(B, L, self.dataset_embedding.embedding_dim)
            x = torch.cat([x, dataset_emb], dim=-1)

        # Input projection
        x = self.input_proj(x)

        # Apply Mamba blocks with residual connections
        for mamba_block in self.mamba_blocks:
            x = x + self._apply_mamba_block(mamba_block, x)

        # Temporal attention
        x_attn, _ = self.temporal_attention(x, x, x)
        x = x + x_attn

        # Use last timestep for classification
        features = x[:, -1, :]

        # Classification
        logits = self.unified_classifier(features)

        outputs = {
            'logits': logits,
            'features': features
        }

        # Dataset-specific heads if needed
        if dataset_ids is not None:
            dataset_outputs = []
            for i, head in enumerate(self.dataset_heads):
                mask = (dataset_ids == i)
                if mask.any():
                    dataset_outputs.append(head(features[mask]))

            if dataset_outputs:
                outputs['dataset_specific'] = dataset_outputs

        return outputs


# ============================================================================
# COMPREHENSIVE TRAINING PIPELINE
# ============================================================================

def _build_sequences(X, y, dataset_ids, seq_len):
    """Build sliding windows of ``seq_len`` rows that stay inside one dataset.

    Returns ``(X_seq, y_seq, ds_seq)``: windows of shape
    (n, seq_len, n_features), the label of each window's last row, and the
    dataset id of each window.  Datasets shorter than ``seq_len`` yield none.
    """
    offsets = np.arange(seq_len)
    seq_parts, label_parts, id_parts = [], [], []

    for ds in np.unique(dataset_ids):
        rows = np.flatnonzero(dataset_ids == ds)
        n_windows = len(rows) - seq_len + 1
        if n_windows <= 0:
            continue
        # (n_windows, seq_len) matrix of row indices, one window per line
        window_rows = rows[np.arange(n_windows)[:, None] + offsets]
        seq_parts.append(X[window_rows])
        label_parts.append(y[window_rows[:, -1]])
        id_parts.append(np.full(n_windows, ds, dtype=np.int64))

    if not seq_parts:
        return (
            np.empty((0, seq_len, X.shape[1]), dtype=X.dtype),
            np.empty((0,), dtype=np.int64),
            np.empty((0,), dtype=np.int64),
        )

    return np.concatenate(seq_parts), np.concatenate(label_parts), np.concatenate(id_parts)


def _snapshot_state(model):
    """Return a detached copy of the model's state dict."""
    return {key: tensor.detach().clone() for key, tensor in model.state_dict().items()}


def _predict(model, data_loader, device, criterion=None):
    """Run the model over a loader in eval mode.

    Returns ``(preds, labels, dataset_ids, mean_batch_loss)`` as numpy arrays
    plus a float; the loss is 0.0 when no ``criterion`` is given.
    """
    model.eval()
    preds, targets, ds_ids = [], [], []
    loss_sum = 0.0

    with torch.no_grad():
        for data, batch_labels, batch_ds_ids in data_loader:
            outputs = model(data.to(device), batch_ds_ids.to(device))
            if criterion is not None:
                loss_sum += criterion(outputs['logits'], batch_labels.to(device)).item()
            preds.append(outputs['logits'].argmax(dim=1).cpu().numpy())
            targets.append(batch_labels.numpy())
            ds_ids.append(batch_ds_ids.numpy())

    mean_loss = loss_sum / len(data_loader)
    return np.concatenate(preds), np.concatenate(targets), np.concatenate(ds_ids), mean_loss


def _plot_results(history, cm, dataset_names, dataset_accs):
    """Plot loss/accuracy curves, the confusion matrix and per-dataset accuracy."""
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))

    # Training history
    axes[0, 0].plot(history['train_loss'], label='Train')
    axes[0, 0].plot(history['val_loss'], label='Val')
    axes[0, 0].set_xlabel('Epoch')
    axes[0, 0].set_ylabel('Loss')
    axes[0, 0].set_title('Training Loss')
    axes[0, 0].legend()
    axes[0, 0].grid(True)

    axes[0, 1].plot(history['train_acc'], label='Train')
    axes[0, 1].plot(history['val_acc'], label='Val')
    axes[0, 1].set_xlabel('Epoch')
    axes[0, 1].set_ylabel('Accuracy (%)')
    axes[0, 1].set_title('Training Accuracy')
    axes[0, 1].legend()
    axes[0, 1].grid(True)

    # Confusion matrix
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=axes[1, 0])
    axes[1, 0].set_xlabel('Predicted')
    axes[1, 0].set_ylabel('Actual')
    axes[1, 0].set_title('Confusion Matrix')

    # Per-dataset accuracy bar chart
    axes[1, 1].bar(dataset_names, dataset_accs)
    axes[1, 1].set_ylabel('Accuracy (%)')
    axes[1, 1].set_title('Per-Dataset Accuracy')
    axes[1, 1].set_ylim([0, 100])

    for i, v in enumerate(dataset_accs):
        axes[1, 1].text(i, v + 1, f'{v:.1f}%', ha='center')

    plt.tight_layout()
    plt.show()


def train_multimodal_mambashield(config=None, dataset_paths=None):
    """Train MambaShield on all three datasets.

    Args:
        config: Optional overrides for the default hyperparameters below.
        dataset_paths: Optional mapping of dataset key to CSV path; defaults
            to the Kaggle input paths.

    Returns:
        Tuple ``(model, history)`` where ``model`` carries the weights of the
        epoch with the best validation accuracy and ``history`` maps
        ``train_loss``/``train_acc``/``val_loss``/``val_acc`` to per-epoch
        lists.  Returns ``(None, None)`` when no data could be loaded or no
        sequence could be built.
    """

    print("="*60)
    print("MAMBASHIELD MULTI-DATASET TRAINING")
    print("="*60)

    # Configuration
    default_config = {
        'seq_len': 10,
        'batch_size': 32,
        'hidden_dim': 128,
        'd_state': 16,
        'n_layers': 3,
        'dropout': 0.1,
        'lr': 1e-4,
        'epochs': 10,
        'sample_frac': 0.05,  # Use 5% of each dataset for demo
        'use_dataset_embedding': True
    }
    config = {**default_config, **(config or {})}

    # Dataset paths
    if dataset_paths is None:
        dataset_paths = {
            'cic': '/kaggle/input/poisoning-i/CIC_IoT_M3.csv',
            'cse': '/kaggle/input/poisoning-i/CSE-CIC_2018.csv',
            'ton': '/kaggle/input/poisoning-i/UNSW_TON_IoT.csv'
        }

    # Display names, indexed like the dataset ids (position in dataset_paths)
    display_names = {'cic': 'CIC-IoT-2023', 'cse': 'CSE-CICIDS2018', 'ton': 'UNSW-NB15'}
    dataset_names = [display_names.get(key, key) for key in dataset_paths]

    # Device setup
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    use_amp = device.type == 'cuda'  # CUDA AMP is meaningless on CPU
    print(f"Using device: {device}")

    # Load all datasets
    print("\n1. Loading Datasets...")
    loader = MultiDatasetLoader(config)
    X, y, dataset_ids = loader.load_all_datasets(dataset_paths, config['sample_frac'])

    if X is None:
        print("Failed to load datasets")
        # Keep the (model, history) shape so callers can always unpack
        return None, None

    # Create sequences (per dataset, so no window mixes two datasets)
    print("\n2. Creating Temporal Sequences...")
    X_seq, y_seq, ds_seq = _build_sequences(X, y, dataset_ids, config['seq_len'])

    if len(X_seq) == 0:
        print("Not enough rows to build any sequence")
        return None, None

    # Split data
    # NOTE(review): consecutive windows overlap, so a random split places
    # near-identical windows in train and test - consider a contiguous split.
    indices = np.arange(len(X_seq))
    train_idx, test_idx = train_test_split(indices, test_size=0.2, random_state=42)
    train_idx, val_idx = train_test_split(train_idx, test_size=0.2, random_state=42)

    def make_dataset(idx):
        return TensorDataset(
            torch.FloatTensor(X_seq[idx]),
            torch.LongTensor(y_seq[idx]),
            torch.LongTensor(ds_seq[idx])
        )

    # Dataloaders
    train_loader = DataLoader(make_dataset(train_idx), batch_size=config['batch_size'], shuffle=True)
    val_loader = DataLoader(make_dataset(val_idx), batch_size=config['batch_size'])
    test_loader = DataLoader(make_dataset(test_idx), batch_size=config['batch_size'])

    # Create model
    print("\n3. Creating Multi-Dataset MambaShield Model...")
    num_classes = len(np.unique(y))
    # Dataset ids are positions in dataset_paths; a skipped file leaves a gap,
    # so size the embedding by the number of paths, not by the ids seen.
    num_datasets = len(dataset_paths)

    model = MultiDatasetMambaShield(
        input_dim=X.shape[1],
        num_classes=num_classes,
        num_datasets=num_datasets,
        config=config
    ).to(device)

    print(f"Model parameters: {sum(p.numel() for p in model.parameters())/1e6:.2f}M")

    # Training setup
    optimizer = optim.AdamW(model.parameters(), lr=config['lr'])
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config['epochs'])
    criterion = nn.CrossEntropyLoss()
    scaler = GradScaler(enabled=use_amp)

    # Training history
    history = defaultdict(list)
    best_val_acc = 0
    # state_dict() returns references to the live parameters, so the best
    # weights must be copied; start from the initial weights so a state
    # always exists even if validation accuracy never improves.
    best_model_state = _snapshot_state(model)

    # Training loop
    print("\n4. Training...")
    for epoch in range(config['epochs']):
        # Training
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0

        progress = tqdm(train_loader, desc=f'Epoch {epoch+1}/{config["epochs"]}')
        for data, batch_labels, batch_ds_ids in progress:
            data = data.to(device)
            batch_labels = batch_labels.to(device)
            batch_ds_ids = batch_ds_ids.to(device)

            optimizer.zero_grad()

            with autocast(enabled=use_amp):
                outputs = model(data, batch_ds_ids)
                loss = criterion(outputs['logits'], batch_labels)

            scaler.scale(loss).backward()
            # Unscale before clipping so the threshold applies to true gradients
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()

            train_loss += loss.item()
            pred = outputs['logits'].argmax(dim=1)
            train_correct += (pred == batch_labels).sum().item()
            train_total += batch_labels.size(0)

        train_acc = 100 * train_correct / train_total

        # Validation
        val_preds, val_targets, _, val_loss = _predict(model, val_loader, device, criterion)
        val_acc = 100 * float((val_preds == val_targets).mean())

        # Update scheduler
        scheduler.step()

        # Save history
        history['train_loss'].append(train_loss / len(train_loader))
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)

        print(f"Epoch {epoch+1}: Train Acc={train_acc:.2f}%, Val Acc={val_acc:.2f}%")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_model_state = _snapshot_state(model)

    # Load best model
    model.load_state_dict(best_model_state)

    # Test evaluation
    print("\n5. Testing on Combined Test Set...")
    all_preds, all_labels, all_ds_ids, _ = _predict(model, test_loader, device)

    # Overall metrics
    overall_acc = accuracy_score(all_labels, all_preds)
    overall_precision = precision_score(all_labels, all_preds, average='weighted', zero_division=0)
    overall_recall = recall_score(all_labels, all_preds, average='weighted', zero_division=0)
    overall_f1 = f1_score(all_labels, all_preds, average='weighted', zero_division=0)

    print("\n" + "="*60)
    print("FINAL RESULTS")
    print("="*60)
    print(f"Overall Accuracy: {overall_acc*100:.2f}%")
    print(f"Overall Precision: {overall_precision:.4f}")
    print(f"Overall Recall: {overall_recall:.4f}")
    print(f"Overall F1-Score: {overall_f1:.4f}")

    # Per-dataset metrics (names and accuracies kept together so the bar
    # chart stays aligned when a dataset is absent from the test split)
    print("\nPer-Dataset Performance:")
    plotted_names = []
    plotted_accs = []
    for i, name in enumerate(dataset_names):
        mask = all_ds_ids == i
        if mask.any():
            ds_acc = accuracy_score(all_labels[mask], all_preds[mask])
            ds_f1 = f1_score(all_labels[mask], all_preds[mask], average='weighted', zero_division=0)
            print(f"  {name}: Accuracy={ds_acc*100:.2f}%, F1={ds_f1:.4f}")
            plotted_names.append(name)
            plotted_accs.append(ds_acc * 100)

    # Attack category performance: one-vs-rest scores over the full test set.
    # Names come from the label encoder, whose class order defines the
    # integer labels (taxonomy key order does not).
    print("\nPer-Attack Category Performance:")
    class_ids = np.arange(num_classes)
    class_precision = precision_score(all_labels, all_preds, labels=class_ids, average=None, zero_division=0)
    class_recall = recall_score(all_labels, all_preds, labels=class_ids, average=None, zero_division=0)
    encoder_classes = list(loader.label_encoder.classes_)

    for i in class_ids:
        if not (all_labels == i).any():
            continue
        category = encoder_classes[i] if i < len(encoder_classes) else f'class_{i}'
        print(f"  {category}: Precision={class_precision[i]:.4f}, Recall={class_recall[i]:.4f}")

    # Plot results
    cm = confusion_matrix(all_labels, all_preds, labels=class_ids)
    _plot_results(history, cm, plotted_names, plotted_accs)

    return model, history

if __name__ == "__main__":
    # The trainer returns without a (model, history) pair when the datasets
    # cannot be loaded; guard the unpacking so the script ends cleanly
    # instead of raising TypeError on a None result.
    result = train_multimodal_mambashield()
    model, history = result if result is not None else (None, None)

