# -*- coding: utf-8 -*-
"""
Integrated DQN-Actor-Critic Architecture for Poisoning Attack Detection
Author: Your Name
Date: 2024
"""

"""# Cell 1: Imports and Mathematical Framework

## Mathematical Framework for DQN-Actor-Critic Based Poisoning Detection

1. Feature Space Definition:
   F = {f ∈ ℝᵈ | f is a network flow feature vector}
   where d is the dimension of extracted features

2. Data Distribution:
   P(x) = Normal network traffic distribution
   Q(x) = Poisoned traffic distribution
   KL(P||Q) = Measure of distribution divergence

3. Attack Space:
   A = {a | a is a poisoning attack vector}
   Impact(a) = Σ ||f_original - f_poisoned||₂

4. Detection Functions:
   DQN: Q(s,a) = 𝔼[R + γ max Q(s',a') | s,a]
   Actor: π(a|s) = P(action=a | state=s)
   Critic: V(s) = 𝔼[Σ γᵗR_t | s₀=s]

5. Combined Detection Score:
   D(x) = α·Q(x) + (1-α)·π(x) + β·KL(P||Q_x)
   where Q_x is the estimated distribution at point x
"""

## GPU verification

In [None]:
def check_gpu():
    """Report the CUDA device (if any) and return the torch device to use.

    Returns:
        ``torch.device("cuda:0")`` when CUDA is available, after printing the
        device details and switching on the cuDNN autotuner and TF32 matmul /
        cuDNN flags; otherwise ``torch.device("cpu")``.
    """
    # Guard clause: nothing to configure without a CUDA device.
    if not torch.cuda.is_available():
        print("No GPU detected. Using CPU.")
        return torch.device("cpu")

    # Describe device 0 (the only device this helper ever selects).
    print("CUDA is available:")
    print(f"- GPU Device: {torch.cuda.get_device_name(0)}")
    print(f"- CUDA Version: {torch.version.cuda}")
    total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"- Memory Available: {total_gb:.2f}GB")

    selected = torch.device("cuda:0")

    # Global backend switches; these affect every later CUDA op in the process.
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    return selected

## Installs

In [2]:
# Usage in notebook
!pip install torch_xla
!pip install cloud-tpu-client
!pip install cloud-tpu-client==0.10 torch_xla==2.0 torch==2.0.0 torchvision==0.15.1 -f https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-2.0-cp310-cp310-linux_x86_64.whl


[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.0.1[0m[39;49m -> [0m[32;49m24.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.0.1[0m[39;49m -> [0m[32;49m24.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


## Importing Libraries

In [None]:
import os

# XLA runtime knobs. They are exported before torch_xla is imported further
# down the notebook (presumably because torch_xla reads them at import time --
# TODO confirm against the torch_xla version in use).
os.environ['XLA_USE_BF16'] = "1"
os.environ['XLA_TENSOR_ALLOCATOR_MAXSIZE'] = '100000000'

import requests

# Probe the Google Compute Engine metadata endpoint. The response is not used,
# so a failure must not abort the cell: bound the wait with a timeout and
# report (rather than raise) when the endpoint is unreachable, e.g. when the
# notebook runs outside Google Cloud.
try:
    requests.get(
        'http://metadata.google.internal/computeMetadata/v1/instance/name',
        headers={'Metadata-Flavor': 'Google'},
        timeout=5,
    )
except requests.exceptions.RequestException as exc:
    print(f"GCE metadata server not reachable: {exc}")

In [None]:
# Standard imports

import gc
import time
import random
import logging
import traceback
from typing import Dict, List, Tuple
from collections import deque, defaultdict

# Data processing
import numpy as np
import pandas as pd
import scipy.stats
from sklearn.model_selection import train_test_split

# Deep Learning
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

# Machine Learning
from sklearn.preprocessing import LabelEncoder, StandardScaler
from tensorflow.keras.utils import Sequence
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from typing import Dict, Any, List, Union

# Utilities
import psutil
from tqdm import tqdm
import contextlib
from torch.nn.functional import sigmoid

import torch.cuda
import torch.backends.cudnn

import tensorflow as tf
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp



E0000 00:00:1732057941.579875      38 common_lib.cc:798] Could not set metric server port: INVALID_ARGUMENT: Could not find SliceBuilder port 8471 in any of the 0 ports provided in `tpu_process_addresses`="local"
=== Source Location Trace: ===
learning/45eac/tfrc/runtime/common_lib.cc:479
D1119 23:12:21.587977505      38 config.cc:196]                        gRPC EXPERIMENT call_status_override_on_cancellation   OFF (default:OFF)
D1119 23:12:21.587992022      38 config.cc:196]                        gRPC EXPERIMENT call_v3                                OFF (default:OFF)
D1119 23:12:21.587995420      38 config.cc:196]                        gRPC EXPERIMENT canary_client_privacy                  ON  (default:ON)
D1119 23:12:21.587997840      38 config.cc:196]                        gRPC EXPERIMENT capture_base_context                   ON  (default:ON)
D1119 23:12:21.588000225      38 config.cc:196]                        gRPC EXPERIMENT client_idleness                        ON  (defau

## Add TPU call

In [None]:
def setup_tpu_colab():
    """Check whether a Colab/Kaggle TPU worker is configured and responding.

    The check is driven by two environment variables: ``TPU_NAME`` (is a TPU
    configured at all?) and ``COLAB_TPU_ADDR`` (where to reach the worker).

    Returns:
        True only when both variables are set and the worker answers the
        ``tpu_worker_state`` request with a successful HTTP status; False in
        every other case (no TPU configured, address missing, connection
        failure, timeout, or a non-OK response).
    """
    if not os.environ.get('TPU_NAME'):
        return False

    # The URL is built from COLAB_TPU_ADDR, not TPU_NAME; look it up explicitly
    # instead of letting a KeyError surface as an opaque error message.
    tpu_addr = os.environ.get('COLAB_TPU_ADDR')
    if not tpu_addr:
        print("Error checking TPU: COLAB_TPU_ADDR is not set")
        return False

    # Imported lazily so hosts without a TPU never need `requests` here.
    import requests

    url = f'http://{tpu_addr}/requestversion/tpu_worker_state'
    timeout = 60  # seconds
    try:
        response = requests.get(url, timeout=timeout)
    except requests.exceptions.ConnectionError:
        print("Failed to connect to TPU")
        return False
    except requests.exceptions.RequestException as e:
        # Covers read timeouts, malformed URLs and other transport errors.
        print(f"Error checking TPU: {str(e)}")
        return False

    if response.ok:
        print("TPU available and responding")
        return True

    return False

def initialize_tpu():
    """Acquire the XLA (TPU) device, falling back to CUDA/CPU on any failure.

    Returns:
        The device returned by ``xm.xla_device()`` when every step succeeds;
        otherwise ``torch.device("cuda")`` if CUDA is available, else
        ``torch.device("cpu")``.
    """
    try:
        # NOTE(review): this call lives inside the try block, so if
        # `clear_replicated` is missing from the installed torch_xla the
        # resulting AttributeError silently forces the CPU/GPU fallback --
        # confirm the API exists for the pinned version.
        xla_has_core = hasattr(torch_xla, 'core')
        if xla_has_core:
            torch_xla.core.xla_model.clear_replicated()

        tpu = xm.xla_device()
        print(f"TPU Device initialized: {tpu}")

        # Smoke test: moving a small tensor onto the device must not raise.
        torch.randn(2, 2).to(tpu)
        print("TPU test successful")
    except Exception as err:
        print(f"Failed to initialize TPU: {str(err)}")
        print("Falling back to CPU/GPU")
        fallback_name = "cuda" if torch.cuda.is_available() else "cpu"
        return torch.device(fallback_name)
    else:
        return tpu


## Label Handling

In [None]:
# At start of notebook
!nvidia-smi
# Report whether PyTorch can see a CUDA device in this runtime.
print(f"CUDA Available: {torch.cuda.is_available()}")

In [None]:
class LabelHandler:
    """Handle both binary and multi-class labels for poisoning detection.

    The handler learns its mappings from the first batch passed to
    ``process_labels`` and reuses them afterwards. Label ``0`` is treated as
    normal traffic; every other label is treated as an attack.
    """

    def __init__(self):
        self.label_mapping = {}   # label id -> human readable name
        self.attack_types = {}    # label id -> name, attack labels only (id != 0)
        self.binary_mapping = {}  # label id -> 0 (normal) or 1 (attack)

    @staticmethod
    def _resolve_label_name(idx, label_names) -> str:
        """Return the display name for label ``idx``.

        ``label_names[idx]`` is used when ``idx`` is an integer that indexes
        into ``label_names``; otherwise the generic ``Class_<idx>`` is used.
        """
        if (
            label_names is not None
            and isinstance(idx, (int, np.integer))
            and 0 <= idx < len(label_names)
        ):
            return label_names[int(idx)]
        return f"Class_{idx}"

    def process_labels(self, labels: np.ndarray, label_names: List[str] = None) -> Tuple[np.ndarray, np.ndarray]:
        """Process labels to maintain both binary and multi-class information

        Args:
            labels: Original multi-class labels
            label_names: Optional list of label names/descriptions, indexed by
                label id

        Returns:
            Tuple of (binary_labels, multi_class_labels). ``multi_class_labels``
            is a copy of ``labels``.
        """
        unique_labels = np.unique(labels)

        # Mappings are built once, from the first batch seen.
        if not self.label_mapping:
            self.label_mapping = {
                idx: self._resolve_label_name(idx, label_names)
                for idx in unique_labels
            }

            # Create binary mapping (0 for normal, 1 for any attack)
            self.binary_mapping = {
                idx: 0 if idx == 0 else 1  # Assuming 0 is normal traffic
                for idx in unique_labels
            }

            # Store attack types separately
            self.attack_types = {
                idx: label
                for idx, label in self.label_mapping.items()
                if idx != 0  # Exclude normal traffic
            }

        # Labels that were absent from the first batch fall back to the same
        # "0 is normal, anything else is an attack" rule instead of raising
        # KeyError.
        binary_labels = np.array([
            self.binary_mapping.get(l, 0 if l == 0 else 1) for l in labels
        ])
        multi_labels = labels.copy()

        return binary_labels, multi_labels

    def get_attack_info(self, attack_id: int) -> Dict[str, Union[str, bool, int]]:
        """Get information about a specific attack type.

        Unknown ids are reported as ``Unknown_Attack_<id>`` and flagged as
        attacks.
        """
        if attack_id not in self.label_mapping:
            return {
                'attack_name': f'Unknown_Attack_{attack_id}',
                'is_attack': True,
                'attack_id': attack_id,
                'binary_class': self.binary_mapping.get(attack_id, 1)
            }

        return {
            'attack_name': self.label_mapping[attack_id],
            'is_attack': self.binary_mapping.get(attack_id, 1) == 1,
            'attack_id': attack_id,
            'binary_class': self.binary_mapping.get(attack_id, 1)
        }

    def get_attack_stats(self, multi_labels: np.ndarray) -> Dict[str, Dict[str, Union[int, float, bool]]]:
        """Get statistics about attack distribution.

        Returns:
            Mapping of attack name to ``count``, ``percentage`` and
            ``is_attack``. Note that ``percentage`` is stored as a fraction in
            [0, 1] (count / total), not scaled to 0-100.
        """
        unique, counts = np.unique(multi_labels, return_counts=True)
        total = len(multi_labels)

        stats = {}
        for u, c in zip(unique, counts):
            attack_info = self.get_attack_info(u)
            stats[attack_info['attack_name']] = {
                'count': int(c),
                'percentage': float(c/total),
                'is_attack': attack_info['is_attack']
            }

        return stats

    def print_distribution(self, labels: np.ndarray):
        """Print distribution of attacks in dataset"""
        stats = self.get_attack_stats(labels)

        print("\nAttack Distribution:")
        print("-" * 50)
        print(f"{'Attack Type':<30} {'Count':>8} {'Percentage':>12}")
        print("-" * 50)

        for attack_name, info in stats.items():
            # The stored value is a fraction; scale it so the printed number
            # matches the '%' suffix.
            print(f"{attack_name:<30} {info['count']:>8} {info['percentage'] * 100:>11.2f}%")

    def __str__(self) -> str:
        return f"LabelHandler with {len(self.attack_types)} attack types"

    def __repr__(self) -> str:
        return f"LabelHandler(n_attacks={len(self.attack_types)}, n_labels={len(self.label_mapping)})"



## System setup

In [None]:
# System setup function
def setup_system():
    """Prepare the process for training: free memory, configure logging, seed RNGs.

    Returns:
        True once every step has completed.

    Raises:
        RuntimeError: if the machine reports less than 8GB of total memory.
    """
    print("\n=== System Setup ===")

    # Release whatever Python and (if present) the CUDA allocator can give back.
    gc.collect()
    cuda_ready = torch.cuda.is_available()
    if not cuda_ready:
        print("Using CPU for training")
    else:
        torch.cuda.empty_cache()
        print(f"Using GPU: {torch.cuda.get_device_name(0)}")
        print(f"CUDA Device: {torch.cuda.get_device_name(0)}")
        print(f"CUDA Memory: {torch.cuda.get_device_properties(0).total_memory/1e9:.2f}GB")
        print(f"CUDA Capability: {torch.cuda.get_device_capability()}")

    # Baseline resident memory of this process.
    this_process = psutil.Process()
    print(f"Initial memory usage: {this_process.memory_info().rss/1024/1024:.2f}MB")

    # Root logger: write to training.log and echo to the console.
    log_handlers = [
        logging.FileHandler('training.log'),
        logging.StreamHandler(),
    ]
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=log_handlers,
    )

    # Refuse to start on machines below the 8GB floor.
    total_gb = psutil.virtual_memory().total / (1024**3)
    if total_gb < 8:
        raise RuntimeError(f"Insufficient memory: {total_gb:.1f}GB < 8GB required")

    # Fixed seed across every RNG the training code draws from.
    seed = 42
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if cuda_ready:
        torch.cuda.manual_seed(seed)

    print("System setup completed successfully")
    return True

# Memory monitoring class
class MemoryMonitor:
    """Periodically log process and system memory usage during training.

    NOTE(review): ``logging.basicConfig`` does nothing when the root logger
    already has handlers (for example after ``setup_system`` has run), in which
    case ``memory_usage.log`` is never attached -- confirm this is acceptable.
    """

    def __init__(self, log_interval=60):
        # Minimum number of seconds between two log entries.
        self.log_interval = log_interval
        self.last_log_time = time.time()

        logging.basicConfig(
            filename='memory_usage.log',
            format='%(asctime)s - %(message)s',
            level=logging.INFO,
        )

    def check_memory(self):
        """Log memory usage if the interval has elapsed.

        Returns:
            True when a log entry was written and system memory usage is above
            90%; False otherwise (including when the interval has not elapsed).
        """
        now = time.time()
        elapsed = now - self.last_log_time
        if elapsed < self.log_interval:
            return False

        proc_mem = psutil.Process().memory_info()
        sys_mem = psutil.virtual_memory()

        logging.info(
            f"Memory Usage - RSS: {proc_mem.rss/1024/1024:.2f}MB, "
            f"VMS: {proc_mem.vms/1024/1024:.2f}MB, "
            f"System Memory Used: {sys_mem.percent}%"
        )

        self.last_log_time = now
        return sys_mem.percent > 90  # Warning threshold

class GPUMemoryManager:
    """Stateless helpers for inspecting and releasing GPU (or CPU) memory."""

    @staticmethod
    def print_memory_stats():
        """Print CUDA allocator usage, or process RSS/VMS when CUDA is absent."""
        if not torch.cuda.is_available():
            proc = psutil.Process()
            print("\nCPU Memory Usage:")
            print(f"RSS: {proc.memory_info().rss/1e9:.2f}GB")
            print(f"VMS: {proc.memory_info().vms/1e9:.2f}GB")
            return

        print("\nGPU Memory Usage:")
        print(f"Allocated: {torch.cuda.memory_allocated()/1e9:.2f}GB")
        print(f"Cached: {torch.cuda.memory_reserved()/1e9:.2f}GB")

    @staticmethod
    def clear_memory():
        """Run the garbage collector and, with CUDA, empty the allocator cache."""
        gc.collect()
        if not torch.cuda.is_available():
            return
        torch.cuda.empty_cache()

    @staticmethod
    def get_memory_usage():
        """Return the current memory usage.

        NOTE(review): the two branches use different scales -- the CUDA branch
        returns allocated / total (a ratio), while the CPU branch returns
        ``memory_percent()`` of the process. Callers comparing the result to a
        threshold should confirm which scale they expect.
        """
        if not torch.cuda.is_available():
            return psutil.Process().memory_percent()

        device_total = torch.cuda.get_device_properties(0).total_memory
        return torch.cuda.memory_allocated() / device_total



## Model configuration

In [None]:
class ModelConfig:
    def __init__(self):
        # Check TPU availability
        if setup_tpu_colab():
            self.device = initialize_tpu()
            self.use_tpu = True
            # TPU-specific settings
            self.batch_size = 1024  # TPU prefers larger batches
            self.num_workers = 8
            self.use_amp = False  # TPU has its own optimization
        else:
            self.use_tpu = False
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            # Regular settings
            self.batch_size = 512 if torch.cuda.is_available() else 128
            self.num_workers = 4 if torch.cuda.is_available() else 2
            self.use_amp = torch.cuda.is_available()
        
        print(f"Using device: {self.device}")
        print(f"Batch size: {self.batch_size}")
        
    """Configuration for the DQN-Actor-Critic model and training process"""
    def __init__(self):
        # Force CUDA device if available
        if torch.cuda.is_available():
            print(f"CUDA is available. Using GPU: {torch.cuda.get_device_name(0)}")
            self.device = torch.device("cuda:0")
        else:
            print("CUDA is not available. Using CPU.")
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Model Architecture
        self.feature_dim = 128
        self.hidden_dim = 256
        self.num_actions = 2
        self.num_heads = 4  # For attention mechanism
        self.dropout_rate = 0.2

        # Training Parameters
        self.learning_rate = 1e-4
        self.batch_size = 256
        self.num_epochs = 100
        self.sequence_length = 10
        self.gamma = 0.99  # Discount factor

        # DQN Specific
        self.epsilon_start = 1.0
        self.epsilon_end = 0.01
        self.epsilon_decay = 0.995

        # Actor-Critic Specific
        self.value_loss_coef = 0.5
        self.entropy_coef = 0.01

        # Integration Parameters
        self.dqn_weight = 0.6  # Weight for DQN in combined decisions
        self.ac_weight = 0.4   # Weight for Actor-Critic in combined decisions

        # Memory and Buffer
        self.replay_buffer_size = 10000
        self.min_replay_size = 1000

        # Optimization
        self.gradient_clip = 1.0
        self.warmup_steps = 1000
        self.target_update_freq = 10

        # Early Stopping
        self.patience = 5
        self.min_delta = 0.001

        # Directories
        self.checkpoint_dir = 'checkpoints'
        self.log_dir = 'logs'

        # Device Specific Optimizations
        self._setup_device_specific()

    def _setup_device_specific(self):
        """Setup device-specific optimizations"""
        if self.device.type == "cuda":
            # GPU settings
            self.use_amp = True  # Enable automatic mixed precision
            self.batch_size = 512  # Larger batch size for GPU
            self.num_workers = 4

            # Enable TF32 for better performance on Ampere GPUs
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True
            torch.backends.cudnn.benchmark = True  # Enable cudnn autotuner
        else:

            # CPU settings
            self.use_amp = False
            self.batch_size = 128
            self.num_workers = min(2, os.cpu_count())
            self.pin_memory = False

        # Enable Intel MKL optimizations if available
        if hasattr(torch, 'set_num_threads'):
            torch.set_num_threads(self.num_workers)
        if hasattr(torch, 'set_num_interop_threads'):
            torch.set_num_interop_threads(self.num_workers)

    def print_config(self):
        """Print the current configuration"""
        print("\nModel Configuration:")
        print(f"- Device: {self.device}")
        print(f"- Feature Dimension: {self.feature_dim}")
        print(f"- Hidden Dimension: {self.hidden_dim}")
        print(f"- Batch Size: {self.batch_size}")
        print(f"- Learning Rate: {self.learning_rate}")
        print(f"- Number of Epochs: {self.num_epochs}")
        print(f"- DQN Weight: {self.dqn_weight}")
        print(f"- Actor-Critic Weight: {self.ac_weight}")
        if self.device.type == "cuda":
            print(f"- AMP Enabled: {self.use_amp}")
            print(f"- CUDA Capability: {torch.cuda.get_device_capability()}")




## Data Loading and Preprocessing

In [None]:
## DATASET SPECIFIC PROCESSOR

class DatasetSpecificProcessor:
    """Processes features specific to each dataset type.

    Supported dataset types (case-insensitive): ``'cic'``, ``'ton'``, ``'cse'``.
    """

    def __init__(self, dataset_type: str):
        self.dataset_type = dataset_type.lower()
        self.protocol_features = {}
        self.temporal_windows = {}
        self.flow_statistics = {}

    def process_features(self, data: pd.DataFrame) -> pd.DataFrame:
        """Process features based on dataset type.

        Args:
            data: Input frame (anything ``pd.DataFrame`` accepts is converted).

        Returns:
            A new DataFrame with the dataset-specific derived columns added.
            The caller's frame is not modified.

        Raises:
            ValueError: if the dataset type is not one of cic / ton / cse.
        """
        # Ensure data is properly formatted as DataFrame with string column names
        if not isinstance(data, pd.DataFrame):
            data = pd.DataFrame(data)

        # Convert numeric column indices to string names if needed. Rename on a
        # shallow copy so the caller's DataFrame keeps its own column labels.
        if all(isinstance(col, int) for col in data.columns):
            data = data.copy(deep=False)
            data.columns = [f'feature_{i}' for i in range(len(data.columns))]

        handlers = {
            'cic': self._process_cic_features,
            'ton': self._process_ton_features,
            'cse': self._process_cse_features,
        }
        handler = handlers.get(self.dataset_type)
        if handler is None:
            raise ValueError(f"Unknown dataset type: {self.dataset_type}")
        return handler(data)

    def _process_cic_features(self, data: pd.DataFrame) -> pd.DataFrame:
        """Process CIC-IoT specific features.

        One-hot encodes every column whose name contains "protocol" and adds
        ``packet_rate`` / ``connection_rate`` when the source columns exist.
        """
        processed_data = data.copy()

        # Find protocol-related columns (case insensitive)
        protocol_columns = [
            col for col in processed_data.columns
            if isinstance(col, str) and 'protocol' in col.lower()
        ]

        # Process protocol features if any exist
        if protocol_columns:
            # A list prefix must have one entry per encoded column, so the
            # fixed 'protocol' prefix is only valid for a single column. With
            # several protocol columns, fall back to pandas' default (each
            # column's own name) to avoid a length-mismatch ValueError.
            prefix = ['protocol'] if len(protocol_columns) == 1 else None
            processed_data = pd.get_dummies(
                processed_data,
                columns=protocol_columns,
                prefix=prefix
            )

        # Calculate packet and connection rates for numeric columns
        numeric_cols = processed_data.select_dtypes(include=[np.number]).columns
        if 'packet_count' in numeric_cols and 'duration' in numeric_cols:
            # Durations are floored at 1e-6 to avoid division by zero.
            processed_data['packet_rate'] = (
                processed_data['packet_count'] /
                processed_data['duration'].clip(lower=1e-6)
            )

        if 'connection_count' in numeric_cols and 'duration' in numeric_cols:
            processed_data['connection_rate'] = (
                processed_data['connection_count'] /
                processed_data['duration'].clip(lower=1e-6)
            )

        return processed_data

    def _process_ton_features(self, data: pd.DataFrame) -> pd.DataFrame:
        """Process TON-IoT specific features.

        Adds a rolling mean and variance (window of 10 rows) for every numeric
        column and a rolling ``<col>_rate`` for every column whose name
        contains "service".
        """
        processed_data = data.copy()
        window_size = 10

        # Process numeric features
        numeric_cols = processed_data.select_dtypes(include=[np.number]).columns
        for col in numeric_cols:
            # Calculate temporal mean and variance safely
            try:
                processed_data[f'{col}_temporal_mean'] = processed_data[col].rolling(
                    window=window_size, min_periods=1
                ).mean()
                processed_data[f'{col}_temporal_var'] = processed_data[col].rolling(
                    window=window_size, min_periods=1
                ).var()
            except Exception as e:
                print(f"Warning: Could not process column {col}: {str(e)}")

        # Find and process service-related columns
        service_columns = [
            col for col in processed_data.columns
            if isinstance(col, str) and 'service' in col.lower()
        ]

        for service in service_columns:
            try:
                # Convert to numeric if needed; values that cannot be parsed
                # become NaN (errors='coerce').
                if not pd.api.types.is_numeric_dtype(processed_data[service]):
                    processed_data[service] = pd.to_numeric(
                        processed_data[service], errors='coerce'
                    )
                # Calculate service rate
                processed_data[f'{service}_rate'] = processed_data[service].rolling(
                    window=window_size, min_periods=1
                ).mean()
            except Exception as e:
                print(f"Warning: Could not process service column {service}: {str(e)}")

        return processed_data

    def _process_cse_features(self, data: pd.DataFrame) -> pd.DataFrame:
        """Process CSE-CIC specific features.

        Adds flow byte rate, rolling packet-size statistics and flow-rate
        features when the corresponding source columns are present.
        """
        processed_data = data.copy()
        window_size = 10

        # Process numeric features only
        numeric_cols = processed_data.select_dtypes(include=[np.number]).columns

        # Process flow-based features
        if 'bytes_transferred' in numeric_cols and 'duration' in numeric_cols:
            processed_data['flow_byte_rate'] = (
                processed_data['bytes_transferred'] /
                processed_data['duration'].clip(lower=1e-6)
            )

        # Process packet statistics
        if 'packet_size' in numeric_cols:
            processed_data['packet_size_mean'] = processed_data['packet_size'].rolling(
                window=window_size, min_periods=1
            ).mean()
            processed_data['packet_size_std'] = processed_data['packet_size'].rolling(
                window=window_size, min_periods=1
            ).std()

        # Process flow statistics
        if 'flow_duration' in numeric_cols:
            processed_data['flow_rate'] = 1.0 / processed_data['flow_duration'].clip(lower=1e-6)
            processed_data['flow_rate_mean'] = processed_data['flow_rate'].rolling(
                window=window_size, min_periods=1
            ).mean()

        return processed_data


class PoisoningFeatureExtractor:
    """Extracts poisoning-specific features.

    The extractor is stateful: it keeps the numeric part of the most recent
    ``window_size`` batches in ``feature_history`` and compares each new batch
    against that history.
    """

    def __init__(self, window_size: int = 100):
        self.window_size = window_size
        self.feature_history = deque(maxlen=window_size)
        self.distribution_history = deque(maxlen=window_size)

    def extract_poisoning_features(self, data: pd.DataFrame) -> Dict[str, np.ndarray]:
        """Extract all poisoning-specific features.

        Returns:
            Dict with keys ``distribution_drift``, ``temporal_consistency``,
            ``protocol_frequency`` and ``traffic_anomaly``. If any extractor
            raises, a warning is printed and every key maps to
            ``np.array([0.0])``.
        """
        # Ensure data is properly formatted
        if not isinstance(data, pd.DataFrame):
            data = pd.DataFrame(data)

        features = {}

        try:
            # Temporal consistency compares the batch with the *previous* one,
            # so it has to be computed before drift detection appends the
            # current batch to the history; otherwise the batch is compared
            # with itself and the score is always 0.
            consistency = self._check_temporal_consistency(data)

            # Extract distribution drift features (also records the batch)
            features.update(self._detect_distribution_drift(data))

            # Temporal consistency features (computed above)
            features.update(consistency)

            # Extract protocol behavior features
            features.update(self._analyze_protocol_behavior(data))

            # Extract traffic pattern anomaly features
            features.update(self._detect_traffic_anomalies(data))

        except Exception as e:
            print(f"Warning: Error extracting poisoning features: {str(e)}")
            # Return default features if extraction fails
            features = {
                'distribution_drift': np.array([0.0]),
                'temporal_consistency': np.array([0.0]),
                'protocol_frequency': np.array([0.0]),
                'traffic_anomaly': np.array([0.0])
            }

        return features

    def _detect_distribution_drift(self, data: pd.DataFrame) -> Dict[str, float]:
        """Score how far the batch's mean/covariance are from the history.

        The batch's numeric columns are appended to ``feature_history`` on
        every call, including when the score could not be computed.
        """
        numeric_data = data.select_dtypes(include=[np.number])
        eps = 1e-8  # Numerical stability constant

        drift_score = 0.0
        if len(self.feature_history) > 0:
            try:
                # Mean of the per-batch column means
                previous_mean = np.nanmean([x.mean() for x in self.feature_history], axis=0)
                current_mean = numeric_data.mean()

                # Stack past batches row-wise so the historical covariance has
                # the same (features x features) shape as the current one.
                # (Stacking column-wise produced a larger matrix whose
                # subtraction failed and forced the score to 0.)
                previous_data = np.vstack([x.values for x in self.feature_history])
                previous_cov = np.cov(previous_data.T) + eps * np.eye(previous_data.shape[1])
                current_cov = numeric_data.cov() + eps * np.eye(numeric_data.shape[1])

                # Calculate drift score
                mean_diff = np.linalg.norm(current_mean - previous_mean)
                cov_diff = np.linalg.norm(current_cov - previous_cov, ord='fro')
                drift_score = float((mean_diff + cov_diff) / (1 + eps))

            except Exception as e:
                print(f"Warning: Error computing drift score: {str(e)}")
                drift_score = 0.0

        self.feature_history.append(numeric_data)
        return {'distribution_drift': drift_score}

    def _check_temporal_consistency(self, data: pd.DataFrame) -> Dict[str, float]:
        """Mean absolute row-wise change versus the most recent stored batch.

        Returns 0.0 when there is no history or no shared numeric column.
        """
        numeric_data = data.select_dtypes(include=[np.number])
        consistency_score = 0.0

        if len(self.feature_history) > 0:
            previous_data = self.feature_history[-1]
            # Keep the batch's column order for a deterministic result.
            common_cols = [
                col for col in numeric_data.columns
                if col in previous_data.columns
            ]

            if common_cols:
                # Compare rows by position: batches cut from a larger frame
                # carry non-overlapping index labels, and label alignment
                # would turn every difference into NaN.
                current = numeric_data.reset_index(drop=True)
                previous = previous_data.reset_index(drop=True)
                consistency_score = np.mean([
                    np.abs(current[col] - previous[col]).mean()
                    for col in common_cols
                ])

        return {'temporal_consistency': consistency_score}

    def _analyze_protocol_behavior(self, data: pd.DataFrame) -> Dict[str, np.ndarray]:
        """Column means of every column whose name contains "protocol"."""
        protocol_columns = [
            col for col in data.columns
            if isinstance(col, str) and 'protocol' in col.lower()
        ]

        if protocol_columns:
            protocol_freqs = data[protocol_columns].mean()
            return {'protocol_frequency': protocol_freqs.values}
        return {'protocol_frequency': np.array([0.0])}

    def _detect_traffic_anomalies(self, data: pd.DataFrame) -> Dict[str, np.ndarray]:
        """Per-row relative deviation of ``packet_count`` from the batch mean."""
        numeric_data = data.select_dtypes(include=[np.number])
        if 'packet_count' in numeric_data.columns:
            expected_count = numeric_data['packet_count'].mean()
            observed_counts = numeric_data['packet_count']

            # The denominator is floored at 1e-6 to avoid dividing by zero.
            anomaly_scores = np.abs(
                (observed_counts - expected_count) /
                np.maximum(expected_count, 1e-6)
            )
            return {'traffic_anomaly': anomaly_scores.values}
        return {'traffic_anomaly': np.array([0.0])}


class PoisoningGenerator:
    """Generates synthetic poisoning samples.

    Attributes:
        epsilon: Step size of the gradient-sign perturbation.
        flip_rate: Per-sample probability used by the label-flipping,
            backdoor and clean-label generators.
    """

    def __init__(self, epsilon: float = 0.1, flip_rate: float = 0.1):
        self.epsilon = epsilon
        self.flip_rate = flip_rate

    def generate_gradient_based_poisoning(self, data: torch.Tensor, loss_fn: callable) -> torch.Tensor:
        """Generate gradient-based poisoning samples.

        Args:
            data: Input tensor; it is not modified.
            loss_fn: Callable mapping a tensor to a scalar loss.

        Returns:
            ``data + epsilon * sign(d loss / d data)``, detached from the graph.
        """
        # Differentiate with respect to a private leaf copy: setting
        # requires_grad on the caller's tensor would leak a side effect and
        # fails outright for tensors that are not graph leaves.
        work = data.detach().clone().requires_grad_(True)
        loss = loss_fn(work)
        gradient = torch.autograd.grad(loss, work)[0]
        poisoned_data = work + self.epsilon * torch.sign(gradient)
        return poisoned_data.detach()

    def generate_label_flipping_attacks(self, labels: np.ndarray) -> np.ndarray:
        """Generate label flipping attacks.

        Each label is replaced by ``1 - label`` with probability ``flip_rate``
        (assumes binary 0/1 labels -- TODO confirm against callers).
        """
        flipped_labels = labels.copy()
        flip_mask = np.random.random(len(labels)) < self.flip_rate
        flipped_labels[flip_mask] = 1 - labels[flip_mask]
        return flipped_labels

    def generate_backdoor_triggers(self, data: np.ndarray, trigger_pattern: np.ndarray) -> np.ndarray:
        """Generate backdoor triggers.

        Adds ``trigger_pattern`` to a random subset of rows (each selected
        with probability ``flip_rate``) of a copy of ``data``.
        """
        poisoned_data = data.copy()
        backdoor_mask = np.random.random(len(data)) < self.flip_rate
        poisoned_data[backdoor_mask] += trigger_pattern
        return poisoned_data

    def generate_clean_label_poisoning(self, data: np.ndarray, boundary_shift: np.ndarray) -> np.ndarray:
        """Generate clean label poisoning.

        Adds ``boundary_shift`` to a random subset of rows (each selected
        with probability ``flip_rate``) of a copy of ``data``.
        """
        poisoned_data = data.copy()
        poison_mask = np.random.random(len(data)) < self.flip_rate
        poisoned_data[poison_mask] += boundary_shift
        return poisoned_data


## DATASET LOADER

class EnhancedDatasetLoader:
    """Chunked CSV loader that cleans, scales and labels a dataset.

    Alongside the scaled feature matrix it gathers per-chunk poisoning
    features and a set of synthetically poisoned variants of the data.
    """
    def __init__(self, dataset_type: str, config: ModelConfig = None):
        self.dataset_type = dataset_type.lower()
        self.config = config if config else ModelConfig()

        # Preferred label column per dataset key (CSE keeps a leading space).
        self.label_columns = dict(cic='Label', ton='label', cse=' Label')

        # Collaborators, created in a fixed order.
        self.feature_scaler = StandardScaler()
        self.label_encoder = LabelEncoder()
        self.label_handler = LabelHandler()
        self.dataset_processor = DatasetSpecificProcessor(self.dataset_type)
        self.poisoning_extractor = PoisoningFeatureExtractor()
        self.poisoning_generator = PoisoningGenerator()
        self.data_stabilizer = DataStabilizer()

        # Bounds used when replacing infinities and clipping features.
        self.max_value = 1e10
        self.min_value = -1e10

        self.validation_thresholds = dict(
            mean_threshold=0.1,
            std_threshold=0.1,
            correlation_threshold=0.95,
            min_class_ratio=0.01,
        )

        # Bookkeeping containers.
        self.stats = defaultdict(list)
        self.feature_columns = None
        self.removed_columns = set()

        print(f"\nInitialized Enhanced Dataset Loader for {self.dataset_type.upper()}")

    def _get_label_column(self, chunk: pd.DataFrame) -> str:
        """Pick the label column of ``chunk``.

        Tries, in order: the dataset-specific name, a list of common names,
        then the first column whose lowercased name contains ``'label'``.

        Raises:
            ValueError: If none of the lookups match.
        """
        available = chunk.columns

        preferred = self.label_columns.get(self.dataset_type)
        if preferred in available:
            return preferred

        usual_names = ('Label', 'label', ' Label', 'type', 'class')
        match = next((name for name in usual_names if name in available), None)
        if match is not None:
            return match

        fuzzy = [col for col in available if 'label' in col.lower()]
        if fuzzy:
            return fuzzy[0]

        raise ValueError(f"No label column found for {self.dataset_type} dataset")

    def _process_chunk(self, chunk: pd.DataFrame) -> Tuple[np.ndarray, np.ndarray]:
        """Split a chunk into a numeric feature matrix and its raw labels.

        Non-label columns are coerced to numbers (unparseable cells become 0).
        Returns ``(None, None)`` after printing the error if anything fails.
        """
        try:
            label_col = self._get_label_column(chunk)
            print(f"Found label column: {label_col}")

            if label_col not in chunk.columns:
                raise ValueError(f"Label column '{label_col}' not found")

            # Coerce every non-label column; a failing column becomes all zeros.
            numeric_chunk = pd.DataFrame()
            for col in [name for name in chunk.columns if name != label_col]:
                try:
                    converted = pd.to_numeric(chunk[col], errors='coerce')
                    numeric_chunk[col] = converted
                except Exception as e:
                    print(f"Warning: Could not convert column {col}: {str(e)}")
                    numeric_chunk[col] = 0

            numeric_chunk = numeric_chunk.fillna(0)

            features = self._preprocess_features(numeric_chunk.values)
            labels = chunk[label_col].values

            has_features = features is not None and len(features) != 0
            if not has_features:
                raise ValueError("No valid features extracted from chunk")

            return features, labels

        except Exception as e:
            print(f"Error processing chunk: {str(e)}")
            return None, None


    def _preprocess_features(self, features: np.ndarray) -> np.ndarray:
        """Replace NaN/inf and clip to ``[min_value, max_value]``.

        Returns ``None`` for missing or empty input, or when an error occurs.
        """
        try:
            if features is None or len(features) == 0:
                return None

            lower, upper = self.min_value, self.max_value

            # NaN -> 0, +/-inf -> the clip bounds, then clamp everything else.
            cleaned = np.nan_to_num(features, nan=0.0, posinf=upper, neginf=lower)
            cleaned = np.clip(cleaned, lower, upper)

            # Last-resort sanitisation if anything non-finite survived.
            if not np.all(np.isfinite(cleaned)):
                print("Warning: Invalid values found after preprocessing")
                cleaned = np.nan_to_num(cleaned, nan=0.0)

            return cleaned

        except Exception as e:
            print(f"Error preprocessing features: {str(e)}")
            return None

    def load_and_process_dataset(self, file_path: str) -> Tuple[np.ndarray, np.ndarray, np.ndarray, Dict]:
        """Read ``file_path`` in 10000-row chunks and build the training arrays.

        Returns:
            ``(X_scaled, binary_labels, multi_labels, extras)`` where ``extras``
            has the keys ``'poisoning_features'``, ``'poisoned_samples'`` and
            ``'validation_stats'``.

        Raises:
            FileNotFoundError: If ``file_path`` does not exist.
            ValueError: If no chunk yields usable features.
            Any other exception is printed with its traceback and re-raised.
        """
        try:
            print(f"\nProcessing {self.dataset_type.upper()} dataset: {file_path}")

            if not os.path.exists(file_path):
                raise FileNotFoundError(f"Dataset file not found: {file_path}")

            chunks, labels = [], []
            poisoning_features = defaultdict(list)

            for chunk in pd.read_csv(file_path, chunksize=10000):
                features, chunk_labels = self._process_chunk(chunk)
                if features is None or len(features) == 0:
                    continue

                try:
                    processed_chunk = self.dataset_processor.process_features(pd.DataFrame(features))
                    processed_features = self._preprocess_features(processed_chunk.values)
                    if processed_features is None or len(processed_features) == 0:
                        continue

                    chunks.append(processed_features)
                    labels.extend(chunk_labels)

                    # Poisoning features are best-effort; a failure keeps the chunk.
                    try:
                        extracted = self.poisoning_extractor.extract_poisoning_features(processed_chunk)
                        for key, value in extracted.items():
                            if value is None or len(value) == 0:
                                continue
                            poisoning_features[key].append(value)
                    except Exception as e:
                        print(f"Warning: Error extracting poisoning features: {str(e)}")

                except Exception as e:
                    print(f"Warning: Error processing chunk features: {str(e)}")

            if not chunks:
                raise ValueError("No valid data chunks processed")

            # Stack the chunks, then fit the scaler and encoder on the full data.
            X = np.vstack(chunks)
            y = np.array(labels)

            X_scaled = self.feature_scaler.fit_transform(X)

            y_encoded = self.label_encoder.fit_transform(y)
            binary_labels, multi_labels = self.label_handler.process_labels(y_encoded)

            # Merge per-chunk poisoning features; fall back to zeros on failure.
            combined_poison_features = {}
            for key, values in poisoning_features.items():
                if not values:
                    continue
                try:
                    combined_poison_features[key] = np.concatenate(values)
                except Exception as e:
                    print(f"Warning: Could not combine poisoning features for {key}: {str(e)}")
                    combined_poison_features[key] = np.zeros(len(X))

            poisoned_samples = self._generate_poisoning_samples(X_scaled, binary_labels)

            print("\nDataset Processing Complete:")
            print(f"- Total samples: {len(y)}")
            print(f"- Feature dimensions: {X_scaled.shape[1]}")
            print(f"- Poisoning features extracted: {list(combined_poison_features.keys())}")

            validation_stats = self._validate_data_quality(X_scaled, binary_labels, multi_labels)
            extras = {
                'poisoning_features': combined_poison_features,
                'poisoned_samples': poisoned_samples,
                'validation_stats': validation_stats
            }
            return X_scaled, binary_labels, multi_labels, extras

        except Exception as e:
            print(f"Error processing dataset: {str(e)}")
            traceback.print_exc()
            raise


    def _validate_data_quality(self, X: np.ndarray, binary_labels: np.ndarray, multi_labels: np.ndarray) -> Dict:
        """Summarise ``X`` and the label arrays.

        Reports per-feature mean/std, label counts via ``np.bincount``, the
        number of NaN entries and the feature correlation matrix.
        """
        feature_stats = {
            'mean': np.mean(X, axis=0),
            'std': np.std(X, axis=0)
        }
        class_distribution = {
            'binary': np.bincount(binary_labels),
            'multi': np.bincount(multi_labels)
        }
        nan_total = np.isnan(X).sum()
        correlations = np.corrcoef(X.T)

        return {
            'feature_stats': feature_stats,
            'class_distribution': class_distribution,
            'missing_values': nan_total,
            'feature_correlations': correlations
        }

    def _generate_poisoning_samples(self, X: np.ndarray, y: np.ndarray) -> Dict[str, np.ndarray]:
        """Build one poisoned variant of the data per attack type.

        The gradient attack uses the mean of squared features as a stand-in
        loss; backdoor and clean-label shifts are drawn from N(0, 0.1) per
        feature.
        """
        generator = self.poisoning_generator
        n_features = X.shape[1]
        X_tensor = torch.FloatTensor(X)

        def dummy_loss(x):
            return (x ** 2).mean()

        poisoned_samples = {}
        poisoned_samples['gradient_based'] = generator.generate_gradient_based_poisoning(
            X_tensor, dummy_loss
        ).numpy()
        poisoned_samples['label_flipping'] = generator.generate_label_flipping_attacks(y)

        # Each shift vector is drawn immediately before the attack that uses it.
        trigger = np.random.normal(0, 0.1, n_features)
        poisoned_samples['backdoor'] = generator.generate_backdoor_triggers(X, trigger)

        shift = np.random.normal(0, 0.1, n_features)
        poisoned_samples['clean_label'] = generator.generate_clean_label_poisoning(X, shift)

        return poisoned_samples


class DatasetStatistics:
    """Track and analyze dataset statistics.

    ``stats`` maps each statistic name to the list of values recorded for it,
    in the order they were passed to ``update``.
    """
    def __init__(self):
        self.stats = defaultdict(list)

    def update(self, batch_stats: Dict):
        """Append every value in ``batch_stats`` to the history of its key."""
        for k, v in batch_stats.items():
            self.stats[k].append(v)

    def get_summary(self) -> Dict:
        """Return mean/std/min/max for every key whose first value is numeric.

        Keys with no recorded values are skipped, as are keys whose first
        value is not an ``int``, ``float`` or NumPy number.
        """
        summary = {}
        for k, v in self.stats.items():
            # Merely reading ``stats[key]`` on the defaultdict creates an empty
            # list; skip those instead of failing on ``v[0]``.
            if not v:
                continue
            if isinstance(v[0], (int, float, np.number)):
                summary[k] = {
                    'mean': np.mean(v),
                    'std': np.std(v),
                    'min': np.min(v),
                    'max': np.max(v)
                }
        return summary

class BatchGenerator:
    """Iterate over ``(X, y)`` in fixed-size mini-batches.

    With ``shuffle=True`` the index order is reshuffled in place at the start
    of every pass. The last batch may be smaller than ``batch_size``.
    """
    def __init__(self, X: np.ndarray, y: np.ndarray, batch_size: int,
                 shuffle: bool = True):
        self.X, self.y = X, y
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.n_samples = len(X)
        self.indices = np.arange(self.n_samples)

    def __len__(self):
        """Number of batches per pass (the partial final batch counts)."""
        batch_count = np.ceil(self.n_samples / self.batch_size)
        return int(batch_count)

    def __iter__(self):
        """Yield ``(X_batch, y_batch)`` tuples covering every sample once."""
        if self.shuffle:
            np.random.shuffle(self.indices)

        step = self.batch_size
        for lo in range(0, self.n_samples, step):
            # Slicing past the end is clamped, so the tail batch is just shorter.
            picked = self.indices[lo:lo + step]
            yield self.X[picked], self.y[picked]



## Neural Network Components

In [None]:
# Feature Extraction
class FeatureExtractor(nn.Module):
    """MLP feature encoder.

    Each hidden width contributes a Linear -> LayerNorm -> ReLU -> Dropout
    group; the output width is the last entry of ``hidden_dims``.
    """
    def __init__(self, input_dim: int, hidden_dims: List[int], dropout_rate: float = 0.2):
        super().__init__()
        self.input_dim = input_dim

        # Pair each layer's input width with its output width.
        widths = [input_dim, *hidden_dims]
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack += [
                nn.Linear(fan_in, fan_out),
                nn.LayerNorm(fan_out),
                nn.ReLU(),
                nn.Dropout(dropout_rate),
            ]

        self.feature_layers = nn.Sequential(*stack)
        print(f"Feature extractor created with input dim: {input_dim}")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` after coercing it to a 2-D ``(rows, features)`` tensor.

        A 1-D input gains a leading axis; inputs with more than two axes are
        flattened to ``(x.size(0), -1)``.
        """
        rank = x.dim()
        if rank == 1:
            x = x.unsqueeze(0)
        elif rank > 2:
            x = x.view(x.size(0), -1)

        encoded = self.feature_layers(x)
        return encoded

class DataStabilizer:
    """Small collection of numerically defensive helpers.

    ``eps`` is the constant added to denominators and standard deviations.
    """
    def __init__(self, eps=1e-8):
        self.eps = eps

    def stabilize_array(self, arr):
        """Return ``arr`` with NaN -> 0.0, +inf -> 1e10 and -inf -> -1e10."""
        return np.nan_to_num(arr, nan=0.0, posinf=1e10, neginf=-1e10)

    def safe_divide(self, numerator, denominator):
        """Divide by ``denominator + eps`` so an exact zero does not blow up."""
        shifted = denominator + self.eps
        return numerator / shifted

    def normalize_features(self, features):
        """Standardise along axis 0 using ``std + eps`` as the scale."""
        center = np.mean(features, axis=0)
        scale = np.std(features, axis=0) + self.eps
        return (features - center) / scale

    def stabilize_gradients(self, tensor):
        """Clamp a torch tensor to [-1e6, 1e6]; anything else is returned as is."""
        if not torch.is_tensor(tensor):
            return tensor
        return torch.clamp(tensor, min=-1e6, max=1e6)



# DQN Component
class DQNStream(nn.Module):
    """DQN stream: LSTM, residual self-attention, then value/advantage heads.

    Both heads emit ``num_actions`` outputs; Q-values are the value output
    plus the mean-centred advantages.
    """
    def __init__(self, feature_dim: int, hidden_dim: int, num_actions: int,
                 num_heads: int = 4):
        super().__init__()
        self.lstm = nn.LSTM(feature_dim, hidden_dim, batch_first=True)
        # The LSTM emits (batch, seq, hidden). MultiheadAttention defaults to
        # (seq, batch, embed), which would make samples of one batch attend to
        # each other, so the batch-first layout is requested explicitly.
        self.attention = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)

        self.value_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_actions)
        )

        self.advantage_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_actions)
        )

    def forward(self, features: torch.Tensor, hidden=None) -> Tuple[torch.Tensor, Tuple]:
        """Compute Q-values for a batch of feature sequences.

        Args:
            features: Tensor of shape ``(batch, seq_len, feature_dim)``.
            hidden: Optional LSTM state ``(h, c)`` to continue from.

        Returns:
            ``(q_values, hidden)`` where ``q_values`` has shape
            ``(batch, num_actions)`` and ``hidden`` is the new LSTM state.
        """
        # LSTM processing
        lstm_out, hidden = self.lstm(features, hidden)

        # Self-attention over the time axis of each sample
        attn_out, _ = self.attention(lstm_out, lstm_out, lstm_out)

        # Residual combination; only the last time step feeds the heads
        combined = lstm_out + attn_out
        last_hidden = combined[:, -1, :]

        # Dueling heads
        values = self.value_head(last_hidden)
        advantages = self.advantage_head(last_hidden)

        # Centre the advantages across actions before adding the value stream
        q_values = values + (advantages - advantages.mean(dim=1, keepdim=True))

        return q_values, hidden

# Actor-Critic Component
class ActorCriticStream(nn.Module):
    """Policy and value heads on top of a shared trunk.

    The actor ends in a softmax over ``num_actions``; the critic emits a
    single scalar per input row.
    """
    def __init__(self, feature_dim: int, hidden_dim: int, num_actions: int):
        super().__init__()

        # Trunk shared by both heads.
        trunk = (
            nn.Linear(feature_dim, hidden_dim),
            nn.ReLU(),
            nn.LayerNorm(hidden_dim),
        )
        self.shared_layer = nn.Sequential(*trunk)

        # Policy head: probabilities over actions.
        policy_head = (
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_actions),
            nn.Softmax(dim=-1),
        )
        self.actor = nn.Sequential(*policy_head)

        # Value head: one state-value estimate.
        value_head = (
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )
        self.critic = nn.Sequential(*value_head)

    def forward(self, features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(action_probs, state_value)`` for ``features``."""
        trunk_out = self.shared_layer(features)
        return self.actor(trunk_out), self.critic(trunk_out)

# Combined Architecture
class DualStreamDetector(nn.Module):
    """Feature encoder feeding a DQN stream and an actor-critic stream.

    The Q-values and action probabilities are concatenated and passed through
    a small fusion MLP to produce the final per-action output.
    """
    def __init__(self, input_dim: int, feature_dim: int, hidden_dim: int,
                 num_actions: int, num_heads: int = 4):
        super().__init__()

        # Encoder goes input_dim -> hidden_dim -> feature_dim.
        encoder_widths = [hidden_dim, feature_dim]
        self.feature_extractor = FeatureExtractor(input_dim, encoder_widths)
        self.dqn_stream = DQNStream(feature_dim, hidden_dim, num_actions, num_heads)
        self.ac_stream = ActorCriticStream(feature_dim, hidden_dim, num_actions)

        # Fusion consumes [q_values | action_probs], hence twice num_actions.
        fused_width = num_actions * 2
        self.fusion_layer = nn.Sequential(
            nn.Linear(fused_width, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_actions),
        )

    def forward(self, x: torch.Tensor, hidden=None) -> Dict[str, torch.Tensor]:
        """Run both streams on ``x`` and fuse their outputs.

        Returns a dict with the keys ``'q_values'``, ``'action_probs'``,
        ``'state_value'``, ``'final_output'`` and ``'hidden'`` (the state
        returned by the DQN stream).
        """
        encoded = self.feature_extractor(x)

        # The DQN stream expects a sequence axis; feed a length-1 sequence.
        as_sequence = encoded.unsqueeze(1)
        q_values, next_hidden = self.dqn_stream(as_sequence, hidden)

        action_probs, state_value = self.ac_stream(encoded)

        fused_input = torch.cat((q_values, action_probs), dim=-1)
        fused_output = self.fusion_layer(fused_input)

        return {
            'q_values': q_values,
            'action_probs': action_probs,
            'state_value': state_value,
            'final_output': fused_output,
            'hidden': next_hidden
        }



## Poisoning Detection System (Base and Enhanced)

In [None]:
class DynamicThresholdManager:
    """Adapts a detection threshold from recent prediction outcomes.

    Keeps the last 1000 correctness flags and confidences. The threshold is
    always clipped to the range [0.3, 0.9].
    """
    def __init__(self, initial_threshold=0.5, adaptation_rate=0.01):
        self.threshold = initial_threshold
        self.adaptation_rate = adaptation_rate
        history_len = 1000
        self.historical_predictions = deque(maxlen=history_len)
        self.confidence_history = deque(maxlen=history_len)

    def update_threshold(self, current_confidence: float, prediction_correct: bool):
        """Record one outcome and return the adjusted threshold.

        While windowed accuracy is below 0.9 the threshold rises in proportion
        to the error rate; otherwise it falls in proportion to the standard
        deviation of the recorded confidences.
        """
        self.historical_predictions.append(prediction_correct)
        self.confidence_history.append(current_confidence)

        accuracy = np.mean(self.historical_predictions)
        confidence_spread = np.std(self.confidence_history)

        rate = self.adaptation_rate
        if accuracy < 0.9:
            candidate = self.threshold + rate * (1 - accuracy)
        else:
            candidate = self.threshold - rate * confidence_spread

        self.threshold = np.clip(candidate, 0.3, 0.9)
        return self.threshold

## GRADUAL POISONING DETECTOR

def sigmoid(x):
    """Numerically stable NumPy sigmoid, ``1 / (1 + exp(-x))``.

    Accepts a scalar or array-like. Scalars return a NumPy scalar; arrays
    return an array of the same shape. Only ``exp(-|x|)`` is ever evaluated,
    so large-magnitude inputs cannot overflow.
    """
    x = np.asarray(x)
    if not np.issubdtype(x.dtype, np.floating):
        x = x.astype(float)

    decay = np.exp(-np.abs(x))  # always in [0, 1]
    # For x >= 0 this is the textbook form; for x < 0 the algebraically equal
    # exp(x) / (1 + exp(x)) avoids computing exp of a large positive number.
    result = np.where(x >= 0, 1.0 / (1.0 + decay), decay / (1.0 + decay))
    # ``[()]`` unwraps a 0-d array to a scalar and leaves n-d arrays as arrays.
    return result[()]

class GradualPoisoningDetector:
    """Scores slow drift across a sliding window of feature histograms.

    Each call stores the incoming features and, from the second call onward,
    a 20-bin histogram of them. Once two histograms are available, successive
    histogram differences are turned into a drift score.
    """
    def __init__(self, window_size=100):
        self.window_size = window_size
        self.feature_history = deque(maxlen=window_size)
        self.distribution_history = deque(maxlen=window_size)

    @staticmethod
    def _neutral_result() -> Dict[str, float]:
        """Result returned when there is too little history or an error occurred."""
        return {
            'gradual_poison_probability': 0.0,
            'change_consistency': 0.0,
            'change_trend': 0.0
        }

    def analyze_gradual_changes(self, current_features: np.ndarray) -> Dict[str, float]:
        """Record ``current_features`` and score the drift seen so far.

        Returns:
            Dict with ``'gradual_poison_probability'`` (sigmoid of the clipped
            drift score), ``'change_consistency'`` (mean absolute histogram
            change) and ``'change_trend'`` (slope of the mean change over
            time). All three are 0.0 until two histograms exist, and also if
            any error occurs (a warning is printed in that case).
        """
        try:
            self.feature_history.append(current_features)

            if len(self.feature_history) < 2:
                return self._neutral_result()

            # eps keeps empty histogram bins strictly positive
            eps = 1e-8
            current_dist = np.histogram(current_features, bins=20)[0] + eps
            self.distribution_history.append(current_dist)

            if len(self.distribution_history) < 2:
                return self._neutral_result()

            # Row i holds histogram[i + 1] - histogram[i]
            distribution_changes = np.diff(
                np.asarray(list(self.distribution_history)), axis=0
            )

            abs_changes = np.abs(distribution_changes) + eps
            mean_change = np.mean(abs_changes)
            std_change = np.std(abs_changes) + eps

            # Bound the score so the sigmoid input stays in a safe range
            gradual_score = np.clip(
                mean_change * std_change * len(self.distribution_history),
                -100, 100
            )

            # A slope needs at least two change rows; otherwise report 0.0
            trend = 0.0
            time_points = np.arange(len(distribution_changes))
            if len(time_points) > 1:
                try:
                    avg_changes = np.mean(distribution_changes, axis=1)
                    trend = np.polyfit(time_points, avg_changes, 1)[0]
                except Exception:
                    # Best effort: a failed fit yields a flat trend. Unlike a
                    # bare ``except``, this lets KeyboardInterrupt/SystemExit
                    # propagate.
                    trend = 0.0

            return {
                'gradual_poison_probability': float(sigmoid(gradual_score)),
                'change_consistency': float(mean_change),
                'change_trend': float(trend)
            }

        except Exception as e:
            print(f"Warning: Error in gradual change analysis: {str(e)}")
            return self._neutral_result()


    def reset(self):
        """Reset detector state"""
        self.feature_history.clear()
        self.distribution_history.clear()

    def get_window_stats(self) -> Dict[str, float]:
        """Get statistics about the current detection window"""
        return {
            'window_size': len(self.feature_history),
            'max_window': self.window_size,
            'distribution_samples': len(self.distribution_history)
        }


## DATA AUGMENTATION

class DataAugmentation:
    """Random noise and feature-order augmentation for 2-D batches.

    Each of the two transforms fires independently with probability
    ``swap_prob``; ``noise_std`` scales the Gaussian noise.
    """
    def __init__(self, noise_std=0.01, swap_prob=0.1):
        self.noise_std = noise_std
        self.swap_prob = swap_prob

    def augment(self, data: torch.Tensor) -> torch.Tensor:
        """Return an augmented copy of ``data``; the input is not modified."""
        result = data.clone()

        # Transform 1: additive Gaussian noise.
        apply_noise = random.random() < self.swap_prob
        if apply_noise:
            result += torch.randn_like(result) * self.noise_std

        # Transform 2: one random column order applied to every row.
        apply_permutation = random.random() < self.swap_prob
        if apply_permutation:
            column_order = torch.randperm(result.size(1))
            result = result[:, column_order]

        return result


# Metrics Tracking
class PoisoningDetectionMetrics:
    """Accumulates per-detection records and summarises them on demand."""
    def __init__(self, label_handler=None):
        # label_handler is optional and only stored here.
        self.label_handler = label_handler
        self.detection_history = []
        self.distribution_stats = []
        self.dqn_metrics = defaultdict(list)
        self.ac_metrics = defaultdict(list)

    def update_metrics(self, features: np.ndarray, prediction: float,
                      true_label: int, dqn_values: np.ndarray = None,
                      ac_probs: np.ndarray = None):
        """Record one detection.

        Stores the prediction, the true label and axis-0 moments of
        ``features``. When given, ``dqn_values`` contributes its mean and std,
        and ``ac_probs`` contributes its entropy.
        """
        feature_stats = {
            'mean': np.mean(features, axis=0),
            'std': np.std(features, axis=0),
            'kurtosis': scipy.stats.kurtosis(features, axis=0),
            'skewness': scipy.stats.skew(features, axis=0)
        }
        record = {
            'prediction': prediction,
            'true_label': true_label,
            'feature_stats': feature_stats
        }
        self.detection_history.append(record)

        if dqn_values is not None:
            self.dqn_metrics['q_values'].append(np.mean(dqn_values))
            self.dqn_metrics['q_std'].append(np.std(dqn_values))

        if ac_probs is not None:
            # 1e-10 guards log(0) for zero-probability actions.
            entropy = -np.sum(ac_probs * np.log(ac_probs + 1e-10))
            self.ac_metrics['policy_entropy'].append(entropy)

    def compute_metrics(self) -> Dict[str, float]:
        """Summarise everything recorded so far.

        Returns an empty dict when nothing has been recorded. Array
        predictions are reduced to their argmax for accuracy and FPR, and to
        their max for the confidence figure.
        """
        if not self.detection_history:
            return {}

        predictions, true_labels = [], []
        for entry in self.detection_history:
            predictions.append(entry['prediction'])
            true_labels.append(entry['true_label'])

        def is_vector(value):
            return isinstance(value, np.ndarray) and value.ndim > 0

        # Class index per prediction (scalars are taken as they are).
        pred_indices = [np.argmax(p) if is_vector(p) else p for p in predictions]

        hits = [p == t for p, t in zip(pred_indices, true_labels)]
        confidences = [p.max() if is_vector(p) else p for p in predictions]

        metrics = {
            'accuracy': np.mean(hits),
            'detection_confidence': np.mean(confidences),
            'false_positive_rate': self._compute_fpr(pred_indices, true_labels)
        }

        if self.dqn_metrics:
            metrics['avg_q_value'] = np.mean(self.dqn_metrics['q_values'])
            metrics['q_value_std'] = np.mean(self.dqn_metrics['q_std'])

        if self.ac_metrics:
            metrics['policy_entropy'] = np.mean(self.ac_metrics['policy_entropy'])

        return metrics

    def _compute_fpr(self, predictions: List[float], true_labels: List[int]) -> float:
        """FP / (FP + TN) over samples whose true label is 0; 0.0 if there are none."""
        fp = 0
        tn = 0
        for p, t in zip(predictions, true_labels):
            if p == 1 and t == 0:
                fp += 1
            elif p == 0 and t == 0:
                tn += 1

        negatives = fp + tn
        if negatives > 0:
            return fp / negatives
        return 0.0



# Base Detection System
class PoisoningDetectionSystem:
    """Base class for poisoning detection.

    Owns a ``DualStreamDetector``, its Adam optimizer, an optional AMP gradient
    scaler, a bounded replay buffer and a training-metrics store.
    """
    def __init__(self, input_dim: int, config: ModelConfig = None):
        """Build the model and training utilities.

        Args:
            input_dim: Width of the raw input feature vectors.
            config: Model configuration; a default ``ModelConfig()`` is used
                when omitted. The fields read here are ``device``,
                ``feature_dim``, ``hidden_dim``, ``num_actions``,
                ``learning_rate``, ``use_amp``, ``replay_buffer_size`` and
                ``sequence_length``.
        """
        if config is None:
            config = ModelConfig()

        self.config = config
        self.device = config.device

        # Initialize model
        self.model = DualStreamDetector(
            input_dim=input_dim,
            feature_dim=config.feature_dim,
            hidden_dim=config.hidden_dim,
            num_actions=config.num_actions
        ).to(self.device)

        # Optimizer and AMP scaler (the scaler is only created on CUDA devices)
        self.optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=config.learning_rate
        )
        self.scaler = torch.cuda.amp.GradScaler() if config.use_amp and self.device.type == "cuda" else None

        # Experience replay
        self.replay_buffer = deque(maxlen=config.replay_buffer_size)
        self.sequence_length = config.sequence_length

        # Metrics
        self.training_metrics = defaultdict(list)

    def preprocess_state(self, state: np.ndarray) -> torch.Tensor:
        """Convert ``state`` to a float tensor on ``self.device``.

        A 1-D input (array or tensor) gains a leading batch axis. Errors are
        printed and re-raised.
        """
        try:
            if isinstance(state, np.ndarray):
                if state.ndim == 1:
                    state = state.reshape(1, -1)
                state_tensor = torch.from_numpy(state).float()
            else:
                state_tensor = state.float()
                if state_tensor.dim() == 1:
                    state_tensor = state_tensor.unsqueeze(0)

            return state_tensor.to(self.device)

        except Exception as e:
            print(f"Error in preprocess_state: {str(e)}")
            raise

    def detect(self, state: np.ndarray, evaluate: bool = False) -> Dict[str, np.ndarray]:
        """Run the model on ``state`` and return its outputs as NumPy arrays.

        Args:
            state: Input features, as accepted by ``preprocess_state``.
            evaluate: When True the forward pass runs under ``torch.no_grad()``;
                otherwise gradients are enabled for the pass.

        Returns:
            Dict with ``'is_poisoning'`` (softmax of the fused output),
            ``'q_values'``, ``'action_probs'`` and ``'confidence'`` (the
            critic's state value).
        """
        grad_context = torch.no_grad() if evaluate else torch.enable_grad()
        with grad_context:
            processed_state = self.preprocess_state(state)
            model_output = self.model(processed_state)

            # Get predictions
            q_values = model_output['q_values']
            action_probs = model_output['action_probs']
            detection_prob = F.softmax(model_output['final_output'], dim=-1)

            # With gradients enabled (the default, evaluate=False) these tensors
            # require grad and ``.numpy()`` would raise; detach them first.
            return {
                'is_poisoning': detection_prob.detach().cpu().numpy(),
                'q_values': q_values.detach().cpu().numpy(),
                'action_probs': action_probs.detach().cpu().numpy(),
                'confidence': model_output['state_value'].detach().cpu().numpy()
            }

class PoisoningLoss(nn.Module):
    """Weighted sum of the DQN, policy, detection and value losses.

    Args:
        dqn_weight: Weight of the smooth-L1 Q-value loss.
        policy_weight: Weight applied to (policy loss + detection loss).
        value_weight: Weight of the MSE state-value loss.
        eps: Lower clamp for probabilities before taking their log.
    """
    def __init__(self, dqn_weight=0.4, policy_weight=0.3, value_weight=0.3, eps=1e-8):
        super().__init__()
        self.dqn_weight = dqn_weight
        self.policy_weight = policy_weight
        self.value_weight = value_weight
        self.eps = eps

        self.dqn_criterion = nn.SmoothL1Loss()
        self.policy_criterion = nn.CrossEntropyLoss()
        self.value_criterion = nn.MSELoss()

    def forward(self, model_output, targets):
        """Compute the combined loss.

        Args:
            model_output: Dict with ``'q_values'``, ``'action_probs'``
                (probabilities, i.e. already softmax-normalised),
                ``'state_value'`` and ``'final_output'`` (raw logits).
            targets: Dict with ``'q_targets'``, ``'actions'`` (class indices),
                ``'returns'`` and ``'labels'`` (class indices).

        Returns:
            Dict with the differentiable ``'total_loss'`` tensor plus the
            float values ``'dqn_loss'``, ``'policy_loss'``, ``'value_loss'``
            and ``'detection_loss'``.
        """
        # Bound Q-values for numerical stability
        q_values = model_output['q_values'].clamp(min=-100, max=100)

        # DQN loss; a non-finite result is replaced by zero
        dqn_loss = self.dqn_criterion(q_values, targets['q_targets'])
        dqn_loss = torch.where(torch.isfinite(dqn_loss), dqn_loss, torch.zeros_like(dqn_loss))

        # Policy loss: 'action_probs' are already probabilities, so take the
        # negative log-likelihood of the clamped probabilities directly.
        # Passing them through softmax and CrossEntropyLoss again would
        # re-normalise them and flatten the gradient signal.
        action_probs = torch.clamp(model_output['action_probs'], min=self.eps, max=1.0)
        policy_loss = F.nll_loss(torch.log(action_probs), targets['actions'])

        # Value loss with bounded predictions
        value_pred = model_output['state_value'].view(-1).clamp(min=-100, max=100)
        value_target = targets['returns'].view(-1).clamp(min=-100, max=100)
        value_loss = self.value_criterion(value_pred, value_target)

        # Detection loss: CrossEntropyLoss applies log-softmax itself, so it is
        # given the raw logits rather than softmaxed probabilities.
        detection_loss = self.policy_criterion(
            model_output['final_output'],
            targets['labels']
        )

        # Combine losses with stability checks
        total_loss = (
            self.dqn_weight * torch.nan_to_num(dqn_loss) +
            self.policy_weight * torch.nan_to_num(policy_loss + detection_loss) +
            self.value_weight * torch.nan_to_num(value_loss)
        )

        return {
            'total_loss': total_loss,
            'dqn_loss': dqn_loss.item(),
            'policy_loss': policy_loss.item(),
            'value_loss': value_loss.item(),
            'detection_loss': detection_loss.item()
        }



class PoisoningDataAugmentation:
    """Randomised tensor augmentations (noise, scaling, permutation, shift) plus mixup."""

    def __init__(self,
                 noise_std=0.01,
                 feature_swap_prob=0.1,
                 feature_scale_range=(0.95, 1.05),
                 temporal_shift_prob=0.1,
                 max_shift=3):
        """Store the augmentation hyper-parameters.

        Args:
            noise_std: Standard deviation of the additive Gaussian noise.
            feature_swap_prob: Probability used to gate the noise, scaling and
                permutation steps (each is drawn independently).
            feature_scale_range: (low, high) bounds of the per-feature scale factor.
            temporal_shift_prob: Probability of rolling 3-D input along dim 1.
            max_shift: Largest absolute roll offset.
        """
        self.noise_std = noise_std
        self.feature_swap_prob = feature_swap_prob
        self.feature_scale_range = feature_scale_range
        self.temporal_shift_prob = temporal_shift_prob
        self.max_shift = max_shift

    def augment(self, data: torch.Tensor) -> torch.Tensor:
        """
        Return a randomly perturbed copy of ``data``; the input is not modified.

        Args:
            data: Input tensor of shape (batch_size, feature_dim) or (batch_size, sequence_length, feature_dim)

        Returns:
            Augmented tensor of the same shape, clamped to [-10, 10]
        """
        out = data.clone()
        # None of the steps below changes the number of dimensions.
        sequential = out.dim() == 3

        # Additive Gaussian noise.
        # NOTE(review): this step and the scaling step are gated by
        # feature_swap_prob; confirm that sharing one probability is intended.
        if random.random() < self.feature_swap_prob:
            out.add_(torch.randn_like(out) * self.noise_std)

        # One random scale factor per feature, broadcast over the leading dims.
        if random.random() < self.feature_swap_prob:
            factors = torch.FloatTensor(out.shape[-1]).uniform_(*self.feature_scale_range)
            factors = factors.view(1, 1, -1) if sequential else factors.view(1, -1)
            out.mul_(factors.to(out.device))

        # Shuffle the order of the feature columns.
        if random.random() < self.feature_swap_prob:
            order = torch.randperm(out.shape[-1])
            out = out[:, :, order] if sequential else out[:, order]

        # Circular shift along the sequence axis (3-D input only).
        if sequential and random.random() < self.temporal_shift_prob:
            offset = random.randint(-self.max_shift, self.max_shift)
            out = torch.roll(out, shifts=offset, dims=1)

        # Keep the result inside a fixed range.
        return torch.clamp(out, min=-10, max=10)

    def augment_batch(self, data: torch.Tensor, labels: torch.Tensor = None) -> tuple:
        """
        Augment ``data`` and pass ``labels`` through untouched.

        Args:
            data: Input tensor
            labels: Optional label tensor

        Returns:
            (augmented_data, labels) when labels are given, otherwise just
            the augmented tensor
        """
        augmented_data = self.augment(data)
        if labels is None:
            return augmented_data
        return augmented_data, labels

    @staticmethod
    def mix_samples(data: torch.Tensor, labels: torch.Tensor, alpha: float = 0.2) -> tuple:
        """
        Mixup: blend every sample (and its label) with a randomly chosen partner.

        Args:
            data: Input tensor
            labels: Label tensor
            alpha: Beta-distribution parameter; values <= 0 disable mixing

        Returns:
            Tuple of (mixed_data, mixed_labels)
        """
        lam = np.random.beta(alpha, alpha) if alpha > 0 else 1

        # Partner of sample i is sample index[i].
        index = torch.randperm(data.size(0)).to(data.device)
        partner_data = data[index]
        partner_labels = labels[index]

        mixed_data = lam * data + (1 - lam) * partner_data
        mixed_labels = lam * labels + (1 - lam) * partner_labels
        return mixed_data, mixed_labels


# ENHANCED DETECTION SYSTEM

class EnhancedPoisoningDetectionSystem(PoisoningDetectionSystem):
    """Detection system extended with gradual-change analysis, an adaptive
    threshold, a rolling pattern memory and a replay-buffer training step.

    Relies on attributes provided by the base class: ``model``, ``config``,
    ``scaler``, ``replay_buffer`` and ``preprocess_state``.
    """

    def __init__(self, input_dim: int, config: ModelConfig = None, label_handler: LabelHandler = None):
        """Build the detector and its helper components.

        Args:
            input_dim: Input dimension, forwarded to the base class.
            config: Model configuration; its ``device`` is used when given,
                otherwise the system runs on CPU.
            label_handler: Optional handler, required only for the per-attack
                lookups (see ``get_attack_info``).
        """
        super().__init__(input_dim, config)

        # Device selection, then move the model created by the base class.
        self.device = config.device if config is not None else torch.device("cpu")
        self.model = self.model.to(self.device)

        # Analysis components
        self.label_handler = label_handler
        self.metrics_tracker = PoisoningDetectionMetrics(label_handler)
        self.threshold_manager = DynamicThresholdManager()
        self.gradual_detector = GradualPoisoningDetector()
        self.pattern_memory = deque(maxlen=1000)  # rolling history used by the sequence analysis

        # Training components
        self.data_stabilizer = DataStabilizer()
        self.criterion = PoisoningLoss()
        self.data_augmentation = PoisoningDataAugmentation()

        # Replaces whatever optimizer the base class may have configured.
        self.optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=1e-4,
            eps=1e-8  # explicit value; identical to Adam's documented default
        )

        # Print initialization info
        print("\nEnhanced Detection System Initialized:")
        print(f"- Input dimension: {input_dim}")
        print(f"- Number of attack types: {len(label_handler.attack_types) if label_handler else 'N/A'}")
        print(f"- Binary classification: Normal vs Attack")
        print(f"- Using label handler: {label_handler is not None}")
        print(f"- Device: {self.device}")

        if self.device.type == "cuda":
            print(f"- GPU Memory Usage: {torch.cuda.memory_allocated()/1e9:.2f}GB")

    @staticmethod
    def _empty_sequence_metrics() -> Dict[str, float]:
        """Neutral result returned when the sequence analysis cannot be computed."""
        return {
            'sequence_score': 0.0,
            'pattern_consistency': 0.0,
            'temporal_correlation': 0.0
        }

    @staticmethod
    def _describe_shape(value):
        """Best-effort shape of ``value`` for diagnostics; 'invalid' when it has none."""
        shape = getattr(value, 'shape', None)
        if shape is not None:
            return tuple(shape)
        if isinstance(value, (list, tuple)):
            try:
                return np.array(value).shape
            except (ValueError, TypeError):
                return 'invalid'
        return 'invalid'

    def preprocess_batch(self, states, binary_labels, multi_labels):
        """Stabilise a raw batch and convert it to tensors on ``self.device``.

        Args:
            states: Raw feature batch accepted by the data stabilizer.
            binary_labels: Labels convertible to a ``LongTensor``.
            multi_labels: Accepted for signature compatibility; not used here.

        Returns:
            Tuple ``(states_tensor, binary_labels_tensor)``.
        """
        states = self.data_stabilizer.stabilize_array(states)

        states_tensor = torch.FloatTensor(states).to(self.device)
        binary_labels_tensor = torch.LongTensor(binary_labels).to(self.device)

        # Very large magnitudes are L2-normalised along dim 1; the emptiness
        # guard avoids calling max() on a tensor with no elements.
        if states_tensor.numel() > 0 and states_tensor.abs().max() > 1e3:
            states_tensor = F.normalize(states_tensor, dim=1)

        return states_tensor, binary_labels_tensor

    def _check_label_handler(self):
        """Return True when a label handler is configured.

        Raises:
            ValueError: If no label handler was supplied at construction.
        """
        if self.label_handler is None:
            raise ValueError("Label handler is required for multi-class attack analysis")
        return True

    def get_attack_info(self, attack_id: int) -> Dict:
        """Look up information about ``attack_id`` via the label handler.

        Raises:
            ValueError: If no label handler is configured.
        """
        self._check_label_handler()
        return self.label_handler.get_attack_info(attack_id)

    def detect(self, state: np.ndarray, evaluate: bool = False,
              labels: np.ndarray = None) -> Dict[str, np.ndarray]:
        """Run the model on ``state`` and combine it with the gradual-change analysis.

        Args:
            state: Input features as a NumPy array or a tensor.
            evaluate: When True gradients are disabled and no augmentation is
                applied; otherwise augmentation is applied with probability 0.3.
            labels: Accepted for signature compatibility; not used here.

        Returns:
            Dict with 'is_poisoning', 'q_values', 'action_probs', 'confidence',
            'gradual_metrics', 'sequence_metrics', 'threshold' and
            'detection_metrics'.
        """
        with torch.no_grad() if evaluate else torch.enable_grad():

            with torch.amp.autocast('cuda') if self.scaler else contextlib.nullcontext():
                # Process state
                processed_state = self.preprocess_state(state)
                if not evaluate and random.random() < 0.3:
                    processed_state = self.data_augmentation.augment(processed_state)

                # Model forward pass
                model_output = self.model(processed_state)
                combined_output = model_output['final_output']
                detection_prob = F.softmax(combined_output, dim=-1)

                # NumPy view of the raw input, converted once and reused below.
                # detach() first: .numpy() refuses tensors that require grad.
                numpy_state = state.detach().cpu().numpy() if torch.is_tensor(state) else state

                try:
                    # Analyze gradual changes with error handling
                    gradual_analysis = self.gradual_detector.analyze_gradual_changes(numpy_state)
                except Exception as e:
                    # Best effort: a failing analysis must not abort detection.
                    print(f"Warning: Error in gradual analysis: {str(e)}")
                    gradual_analysis = {
                        'gradual_poison_probability': 0.0,
                        'change_consistency': 0.0,
                        'change_trend': 0.0
                    }

                # Update threshold
                # NOTE(review): argmax() is taken over the flattened tensor, so
                # "== 1" identifies class 1 only for a single two-class sample.
                current_threshold = self.threshold_manager.update_threshold(
                    detection_prob.detach().mean().item(),
                    detection_prob.detach().argmax().item() == 1
                )

                # Weighted blend of the model probability and the gradual score.
                detection_np = detection_prob.detach().cpu().numpy()
                enhanced_prob = (
                    detection_np * self.config.dqn_weight +
                    gradual_analysis['gradual_poison_probability'] * self.config.ac_weight
                )

                # Update pattern memory
                self.pattern_memory.append({
                    'features': numpy_state,
                    'basic_detection': detection_np,
                    'gradual_score': gradual_analysis['gradual_poison_probability']
                })

                sequence_analysis = self._analyze_sequential_patterns()

                return {
                    'is_poisoning': enhanced_prob,
                    'q_values': model_output['q_values'].detach().cpu().numpy(),
                    'action_probs': model_output['action_probs'].detach().cpu().numpy(),
                    'confidence': model_output['state_value'].detach().cpu().numpy(),
                    'gradual_metrics': gradual_analysis,
                    'sequence_metrics': sequence_analysis,
                    'threshold': current_threshold,
                    'detection_metrics': self.metrics_tracker.compute_metrics()
                }

    def _analyze_sequential_patterns(self) -> Dict[str, float]:
        """Summarise how the last (up to 10) remembered patterns evolve over time.

        Returns:
            Dict with 'sequence_score', 'pattern_consistency' and
            'temporal_correlation'; all 0.0 when fewer than two usable
            patterns exist or the computation fails.
        """
        try:
            if len(self.pattern_memory) < 2:
                return self._empty_sequence_metrics()

            recent_patterns = list(self.pattern_memory)[-10:]

            # Keep only the patterns whose feature shape matches the first
            # valid one. The pattern itself is kept alongside its features so
            # the scores below come from exactly the entries that were used.
            kept_patterns = []
            features_list = []
            base_shape = None
            for pattern in recent_patterns:
                features = pattern['features']
                if features is None:
                    continue
                if isinstance(features, torch.Tensor):
                    features = features.cpu().numpy()
                if base_shape is None:
                    base_shape = features.shape
                if features.shape != base_shape:
                    continue
                kept_patterns.append(pattern)
                features_list.append(features)

            if len(features_list) < 2:
                return self._empty_sequence_metrics()

            try:
                # Mean absolute step-to-step change of the features.
                features_array = np.stack(features_list)
                feature_evolution = np.diff(features_array, axis=0)
                pattern_consistency = np.mean(np.abs(feature_evolution), axis=0)

                detection_scores = [float(np.mean(p['basic_detection'])) for p in kept_patterns]
                gradual_scores = [float(p['gradual_score']) for p in kept_patterns]

                # corrcoef yields NaN when either series is constant.
                temporal_correlation = np.corrcoef(detection_scores, gradual_scores)[0, 1]
                if np.isnan(temporal_correlation):
                    temporal_correlation = 0.0

                mean_consistency = float(np.mean(pattern_consistency))
                sequence_score = float(sigmoid(temporal_correlation * np.mean(pattern_consistency)))

                return {
                    'sequence_score': sequence_score,
                    'pattern_consistency': mean_consistency,
                    'temporal_correlation': float(temporal_correlation)
                }

            except Exception as e:
                print(f"Warning: Error in sequence analysis calculations: {str(e)}")
                return self._empty_sequence_metrics()

        except Exception as e:
            print(f"Warning: Error in sequence pattern analysis: {str(e)}")
            return self._empty_sequence_metrics()

    def train(self, batch_size: int) -> Dict[str, float]:
        """Run one optimisation step on a batch sampled from the replay buffer.

        Args:
            batch_size: Number of experiences to sample (without replacement).

        Returns:
            ``{'status': 'insufficient_samples'}`` when the buffer is too
            small, ``{'error': <stage>, 'details': <message>}`` when a stage
            fails, otherwise ``{'status': 'success', ...}`` with the loss values.
        """
        # Pre-bind the names the error handler reports on, so a failure early
        # in the block cannot raise a second (masking) error from the handler.
        states = binary_labels = None
        try:
            if len(self.replay_buffer) < batch_size:
                return {'status': 'insufficient_samples'}

            # Sample and prepare batch
            indices = np.random.choice(len(self.replay_buffer), batch_size, replace=False)
            batch = [self.replay_buffer[i] for i in indices]

            # Each experience is a 6-tuple; two of its fields are not needed here.
            states, binary_labels, _multi_labels, _predictions, next_states, dones = zip(*batch)

            states = torch.FloatTensor(np.array(states)).to(self.device)
            binary_labels = torch.LongTensor(np.array(binary_labels)).to(self.device)
            next_states = torch.FloatTensor(np.array(next_states)).to(self.device)
            dones = torch.FloatTensor(np.array(dones)).to(self.device)

            # Validate tensor shapes
            if states.dim() != 2 or next_states.dim() != 2:
                raise ValueError(f"Invalid state tensor dimensions: states={states.shape}, next_states={next_states.shape}")
            if binary_labels.dim() != 1:
                raise ValueError(f"Invalid labels tensor dimension: {binary_labels.shape}")

        except Exception as e:
            print(f"Error converting to tensors: {str(e)}")
            print(f"States shape: {self._describe_shape(states)}")
            print(f"Labels shape: {self._describe_shape(binary_labels)}")
            return {'error': 'tensor_conversion', 'details': str(e)}

        # Zero gradients
        self.optimizer.zero_grad()

        # Compute outputs and loss
        outputs = None
        try:
            with torch.amp.autocast('cuda') if self.scaler else contextlib.nullcontext():
                # Forward pass
                outputs = self.model(states)
                next_outputs = self.model(next_states)

                # DQN targets start from the current estimates; for
                # non-terminal samples the entry of the labelled action is
                # replaced by the discounted best next Q-value.
                next_q_values = next_outputs['q_values'].detach()
                q_targets = outputs['q_values'].clone().detach()

                # NOTE(review): the sampled tuple carries no reward, so the
                # value target is built from the current value estimate plus
                # the discounted next value; confirm this is intended.
                returns = torch.zeros_like(outputs['state_value'])
                for i in range(batch_size):
                    if not dones[i]:
                        q_targets[i, binary_labels[i]] = self.config.gamma * next_q_values[i].max()
                    returns[i] = outputs['state_value'][i] + \
                        (1 - dones[i]) * self.config.gamma * next_outputs['state_value'][i].detach()

                # Compute loss
                targets = {
                    'q_targets': q_targets,
                    'actions': binary_labels,
                    'returns': returns,
                    'labels': binary_labels
                }

                loss_dict = self.criterion(outputs, targets)
                total_loss = loss_dict['total_loss']

                # Check for invalid loss values
                if torch.isnan(total_loss) or torch.isinf(total_loss):
                    raise ValueError(f"Invalid loss value: {total_loss.item()}")

        except Exception as e:
            print(f"Error in forward pass or loss computation: {str(e)}")
            q_shape = outputs['q_values'].shape if outputs is not None and 'q_values' in outputs else 'invalid'
            print(f"Model outputs shape: {q_shape}")
            return {'error': 'forward_pass', 'details': str(e)}

        # Backward pass; gradients are clipped on both the AMP and plain paths
        try:
            if self.scaler:
                self.scaler.scale(total_loss).backward()
                # Gradients must be unscaled before clipping, otherwise the
                # threshold would be compared against scaled values.
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.gradient_clip)
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                total_loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.gradient_clip)
                self.optimizer.step()
        except Exception as e:
            print(f"Error in backward pass or optimization: {str(e)}")
            return {'error': 'backward_pass', 'details': str(e)}

        # Return metrics
        try:
            return {
                'status': 'success',
                'total_loss': total_loss.item(),
                'dqn_loss': loss_dict['dqn_loss'],
                'policy_loss': loss_dict['policy_loss'],
                'value_loss': loss_dict['value_loss'],
                'detection_loss': loss_dict['detection_loss']
            }
        except Exception as e:
            print(f"Error computing metrics: {str(e)}")
            return {'error': 'metrics_computation', 'details': str(e)}

    def _compute_td_error(self, state_values, next_state_values, rewards, dones):
        """Return the TD error ``rewards + (1 - dones) * gamma * next_state_values - state_values``."""
        target_values = rewards + (1 - dones) * self.config.gamma * next_state_values
        td_error = target_values - state_values
        return td_error



# Multi-Dataset Generator

In [None]:
class MultiDatasetGenerator(Sequence):
    """Serve batches drawn from several datasets, picking a dataset per batch
    with probability proportional to its size.

    Iteration is capped at ``max_epoch_iterations`` batches per epoch.
    """

    def __init__(self, datasets: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
                 batch_size=256, shuffle=True, num_workers=0):
        """Set up per-dataset index arrays and sampling weights.

        Args:
            datasets: Mapping from dataset name to a
                ``(X, binary_labels, multi_labels)`` triple of index-aligned arrays.
            batch_size: Maximum number of samples per batch.
            shuffle: Whether the index arrays are shuffled on reset / new iteration.
            num_workers: Stored for reference; not used inside this class.
        """
        # Initialize dataset parameters
        self.datasets = datasets
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.num_workers = num_workers

        # Initialize indices for each dataset
        self.indices = {k: np.arange(len(v[0])) for k, v in datasets.items()}

        # Calculate weights for dataset sampling
        total_samples = sum(len(v[0]) for v in datasets.values())
        self.dataset_weights = {k: len(v[0])/total_samples for k, v in datasets.items()}

        # Initialize iteration control
        self.current_epoch_iterations = 0
        self.max_epoch_iterations = 1000  # Maximum iterations per epoch

        print(f"Initialized MultiDatasetGenerator with {len(datasets)} datasets")
        for k, v in datasets.items():
            print(f"- {k}: {len(v[0])} samples")

    def _draw_batch(self, batch_number):
        """Pick a dataset at random and return its ``batch_number``-th slice.

        The slice offset wraps around the dataset length, so the last slice
        before a wrap may hold fewer than ``batch_size`` samples.
        """
        chosen_dataset = np.random.choice(
            list(self.datasets.keys()),
            p=list(self.dataset_weights.values())
        )

        X, binary_labels, multi_labels = self.datasets[chosen_dataset]
        indices = self.indices[chosen_dataset]

        start_idx = (batch_number * self.batch_size) % len(indices)
        batch_indices = indices[start_idx:start_idx + self.batch_size]

        return (
            X[batch_indices],
            binary_labels[batch_indices],
            multi_labels[batch_indices]
        )

    def __iter__(self):
        """Reset iteration state and shuffle if needed"""
        self.reset()
        return self

    def __next__(self):
        """Return the next batch, or stop once the per-epoch cap is reached."""
        if self.current_epoch_iterations >= self.max_epoch_iterations:
            raise StopIteration

        # Use the count of batches already served as the batch number, so the
        # first call starts at offset 0 instead of skipping the first slice.
        batch = self._draw_batch(self.current_epoch_iterations)
        self.current_epoch_iterations += 1
        return batch

    def __getitem__(self, index):
        """Return the batch for ``index``.

        Raises:
            IndexError: If ``index`` is at or beyond ``max_epoch_iterations``.
        """
        if index >= self.max_epoch_iterations:
            raise IndexError("Batch index out of range")
        return self._draw_batch(index)

    def __len__(self):
        """Number of batches needed to cover all samples once, capped at ``max_epoch_iterations``."""
        total_samples = sum(len(v[0]) for v in self.datasets.values())
        return min(
            self.max_epoch_iterations,
            int(np.ceil(total_samples / self.batch_size))
        )

    def reset(self):
        """Reset the iteration counter and reshuffle the indices when enabled."""
        self.current_epoch_iterations = 0
        if self.shuffle:
            for k in self.indices:
                np.random.shuffle(self.indices[k])

    def get_progress(self):
        """Return the batches served so far, the cap, and their ratio."""
        return {
            'current_iteration': self.current_epoch_iterations,
            'max_iterations': self.max_epoch_iterations,
            'progress': self.current_epoch_iterations / self.max_epoch_iterations
        }


## Training Pipeline

In [None]:
## EARLY STOPPING

class EarlyStopping:
    """Stop training when a higher-is-better validation score stops improving.

    Two rules are applied after every call:

    * plateau rule: ``patience`` consecutive non-improving epochs whose score
      lies within ``min_improvement`` of the best score;
    * ceiling rule: ``max_epochs_without_improvement`` consecutive epochs
      without an improvement of more than ``min_delta``.
    """

    def __init__(self, patience=5, min_delta=0.001, restore_best_weights=True):
        """
        Args:
            patience: Number of consecutive plateau epochs tolerated.
            min_delta: Margin by which a score must exceed the best to count
                as an improvement.
            restore_best_weights: Whether to snapshot the model weights
                whenever a new best score is seen.
        """
        self.patience = patience
        self.min_delta = min_delta
        self.restore_best_weights = restore_best_weights
        self.best_weights = None
        self.best_score = None
        self.counter = 0  # consecutive plateau epochs
        self.early_stop = False
        self.min_improvement = 1e-4  # half-width of the plateau band around the best score
        self.max_epochs_without_improvement = 10
        self.epochs_without_improvement = 0

    def __call__(self, val_score: float, model=None) -> bool:
        """Record ``val_score`` and report whether training should stop.

        Args:
            val_score: Validation score for the finished epoch (higher is better).
            model: Optional model whose weights are snapshotted on a new best.

        Returns:
            True once either stopping rule has fired; stays True afterwards
            until ``reset`` is called.
        """
        if self.is_best_epoch(val_score):
            self.best_score = val_score
            # An improving epoch is never a plateau epoch: both streaks restart.
            self.epochs_without_improvement = 0
            self.counter = 0
            if self.restore_best_weights and model is not None:
                self.best_weights = self._get_model_weights(model)
        else:
            self.epochs_without_improvement += 1
            if abs(val_score - self.best_score) < self.min_improvement:
                self.counter += 1
            else:
                self.counter = 0

        if (self.epochs_without_improvement >= self.max_epochs_without_improvement
                or self.counter >= self.patience):
            self.early_stop = True

        return self.early_stop

    def _get_model_weights(self, model) -> dict:
        """Get a detached CPU copy of every entry in the model's state dict"""
        return {
            name: param.cpu().clone().detach()
            for name, param in model.state_dict().items()
        }

    def restore_weights(self, model) -> None:
        """Load the best snapshot into ``model`` (no-op when none was taken)"""
        if self.restore_best_weights and self.best_weights is not None:
            model.load_state_dict(self.best_weights)

    def reset(self) -> None:
        """Reset early stopping state, including both streak counters"""
        self.best_score = None
        self.counter = 0
        self.epochs_without_improvement = 0
        self.early_stop = False
        self.best_weights = None

    def get_best_score(self) -> float:
        """Return the best score achieved, or -inf before the first call"""
        return self.best_score if self.best_score is not None else float('-inf')

    def is_best_epoch(self, val_score: float) -> bool:
        """Check whether ``val_score`` beats the best score by more than ``min_delta``"""
        return self.best_score is None or val_score > self.best_score + self.min_delta



## Single Dataset Trainer

In [None]:
## SINGLE DATASET TRAINER
class SingleDatasetTrainer:
    """Runs the load, split, train and save workflow for one named dataset"""

    def __init__(self, config: ModelConfig, dataset_type: str):
        """Create the loader, trackers, log writer and checkpoint directory.

        Args:
            config: Model configuration; ``device``, ``checkpoint_dir`` and
                ``batch_size`` are read from it.
            dataset_type: Dataset name; stored lower-cased.
        """
        self.config = config
        self.dataset_type = dataset_type.lower()
        self.device = config.device

        # Dataset loading and bookkeeping helpers
        print(f"\nInitializing loader for {self.dataset_type.upper()}")
        self.loader = EnhancedDatasetLoader(dataset_type=self.dataset_type, config=self.config)
        self.memory_monitor = MemoryMonitor()
        self.metrics_tracker = PoisoningDetectionMetrics()

        # Logging and checkpoint locations are namespaced per dataset
        self.writer = SummaryWriter(f'logs/{self.dataset_type}')
        self.checkpoint_dir = os.path.join(config.checkpoint_dir, self.dataset_type)
        os.makedirs(self.checkpoint_dir, exist_ok=True)

    def _split_train_val(self, X, binary_labels):
        """Convert to arrays and make a stratified 80/20 split.

        Returns:
            Tuple ``(X, X_train, X_val, y_train, y_val)`` where ``X`` is the
            float32 array form of the input features.
        """
        try:
            features = np.array(X, dtype=np.float32)
            targets = np.array(binary_labels, dtype=np.int32)

            X_train, X_val, y_train, y_val = train_test_split(
                features, targets,
                test_size=0.2,
                random_state=42,
                stratify=targets  # keep the class ratio equal in both parts
            )
            print(f"Train set size: {len(X_train)}, Validation set size: {len(X_val)}")
        except Exception as e:
            print(f"Error in data splitting: {str(e)}")
            raise
        return features, X_train, X_val, y_train, y_val

    def _make_generator(self, features, labels):
        """Wrap one split in a generator keyed by this trainer's dataset name."""
        # NOTE(review): the binary labels are passed in the multi-label slot as well.
        return MultiDatasetGenerator(
            {self.dataset_type: (features, labels, labels)},
            batch_size=self.config.batch_size
        )

    def train_on_dataset(self, file_path: str):
        """Load ``file_path``, train a detection system on it and save the results.

        Args:
            file_path: Path of the dataset file.

        Returns:
            Tuple ``(detection_system, metrics_tracker)``.

        Raises:
            FileNotFoundError: If ``file_path`` does not exist. Any other
                failure is printed with its traceback and re-raised.
        """
        try:
            print(f"\nStarting training process for {self.dataset_type.upper()}")

            # Fail early on a missing file
            if not os.path.exists(file_path):
                raise FileNotFoundError(f"Dataset file not found: {file_path}")

            # NOTE(review): the multi-class labels returned here are not used afterwards.
            X, binary_labels, multi_labels, additional_info = self.loader.load_and_process_dataset(file_path)

            print("\nSplitting data into train and validation sets...")
            X, X_train, X_val, y_train, y_val = self._split_train_val(X, binary_labels)

            detection_system = EnhancedPoisoningDetectionSystem(
                input_dim=X.shape[1],
                config=self.config,
                label_handler=self.loader.label_handler
            )
            print("\nInitialized detection system")

            train_generator = self._make_generator(X_train, y_train)
            val_generator = self._make_generator(X_val, y_val)
            print("Created data generators")

            pipeline = ComprehensiveTrainingPipeline(
                detection_system=detection_system,
                data_generator=train_generator,
                val_generator=val_generator,
                config=self.config,
                label_handler=self.loader.label_handler
            )

            print(f"\nStarting training on {self.dataset_type.upper()}...")
            pipeline.train()

            self._save_results(detection_system, additional_info)

            return detection_system, self.metrics_tracker

        except Exception as e:
            print(f"Error in training process: {str(e)}")
            traceback.print_exc()
            raise

    def _save_results(self, detection_system: EnhancedPoisoningDetectionSystem, additional_info: Dict):
        """Write model weights, config, metrics and dataset info to the checkpoint directory.

        Args:
            detection_system: Trained system whose ``model`` state is saved.
            additional_info: Must contain 'poisoning_features' and 'validation_stats'.
        """
        try:
            results_path = os.path.join(self.checkpoint_dir, f'{self.dataset_type}_results.pt')

            payload = {
                'model_state': detection_system.model.state_dict(),
                'config': self.config,
                'metrics': self.metrics_tracker.get_summary(),
                'poisoning_features': additional_info['poisoning_features'],
                'validation_stats': additional_info['validation_stats']
            }
            torch.save(payload, results_path)

            print(f"\nResults saved to {results_path}")

        except Exception as e:
            print(f"Error saving results: {str(e)}")
            raise




## Comprehensive Training Pipeline

In [None]:
## COMPREHENSIVE TRAINING PIPELINE

class ComprehensiveTrainingPipeline:
    """Advanced training pipeline with monitoring and evaluation.

    Connects a detection system to training/validation batch generators and
    drives the epoch loop, TensorBoard logging, periodic validation, early
    stopping and checkpointing.
    """

    def __init__(self, detection_system: EnhancedPoisoningDetectionSystem,
                 data_generator: MultiDatasetGenerator,
                 val_generator: MultiDatasetGenerator,
                 config: ModelConfig,
                 label_handler: LabelHandler):
        """Store collaborators, copy training parameters and set up monitoring.

        Args:
            detection_system: System whose ``model``, ``device``,
                ``optimizer``, ``replay_buffer``, ``data_stabilizer``,
                ``detect`` and ``train`` members are used by this class.
            data_generator: Iterable yielding
                ``(states, binary_labels, multi_labels)`` training batches;
                must support ``len()``.
            val_generator: Same batch contract for validation. A falsy value
                disables validation.
            config: Supplies ``batch_size``, ``num_epochs``,
                ``checkpoint_dir``, ``log_dir``, ``patience`` and
                ``min_delta``.
            label_handler: Supplies ``attack_types``, ``label_mapping`` and
                ``get_attack_info``.
        """
        # Core components
        self.detection_system = detection_system
        self.data_generator = data_generator
        self.val_generator = val_generator
        self.config = config
        self.label_handler = label_handler

        # Training parameters
        self.batch_size = config.batch_size
        self.num_epochs = config.num_epochs
        self.checkpoint_dir = config.checkpoint_dir
        self.log_dir = config.log_dir

        # Progress and best-result trackers
        self._global_step = 0
        self._epoch = 0
        self.best_metrics = {
            'accuracy': 0,
            'f1_score': 0,
            'unknown_detection_rate': 0,
            'per_attack_f1': defaultdict(float)
        }

        # Directories, writer, early stopping, loss, ...
        self._setup_components()

    def preprocess_batch(self, states, binary_labels, multi_labels):
        """Stabilize a raw batch and convert it to tensors on the target device.

        Args:
            states: Batch of feature vectors accepted by
                ``data_stabilizer.stabilize_array``.
            binary_labels: Per-sample labels convertible by
                ``torch.LongTensor``.
            multi_labels: Accepted for signature parity with
                ``_train_batch``; not used here.

        Returns:
            ``(states_tensor, binary_labels_tensor)``, both moved to
            ``detection_system.device``.

        Raises:
            Exception: Any failure is re-raised after printing diagnostics.
        """
        try:
            # Get data stabilizer from detection system
            data_stabilizer = self.detection_system.data_stabilizer

            # Stabilize states
            states = data_stabilizer.stabilize_array(states)

            # Convert to tensors with proper dtype
            states_tensor = torch.FloatTensor(states).to(self.detection_system.device)
            binary_labels_tensor = torch.LongTensor(binary_labels).to(self.detection_system.device)

            # Rescale each row to unit norm when any magnitude exceeds 1e3
            if states_tensor.abs().max() > 1e3:
                states_tensor = F.normalize(states_tensor, dim=1)

            return states_tensor, binary_labels_tensor

        except Exception as e:
            print(f"Error in batch preprocessing: {str(e)}")
            print(f"States shape: {states.shape if isinstance(states, np.ndarray) else 'invalid'}")
            print(f"Labels shape: {binary_labels.shape if isinstance(binary_labels, np.ndarray) else 'invalid'}")
            raise

    def _setup_components(self):
        """Create output directories and initialize monitoring helpers.

        Sets ``writer``, ``memory_monitor``, ``early_stopping``,
        ``attack_metrics``, ``train_losses``, ``val_losses``,
        ``performance_history`` and ``criterion`` on the instance.

        Raises:
            Exception: Any setup failure is printed and re-raised.
        """
        try:
            # Create directories
            os.makedirs(self.checkpoint_dir, exist_ok=True)
            os.makedirs(self.log_dir, exist_ok=True)

            # Initialize monitoring
            self.writer = SummaryWriter(self.log_dir)
            self.memory_monitor = MemoryMonitor()
            self.early_stopping = EarlyStopping(
                patience=self.config.patience,
                min_delta=self.config.min_delta
            )

            # One metric history per known attack type
            self.attack_metrics = {
                attack_id: defaultdict(list)
                for attack_id in self.label_handler.attack_types
            }

            # Performance tracking
            self.train_losses = []
            self.val_losses = []
            self.performance_history = defaultdict(list)

            # Loss function
            self.criterion = PoisoningLoss()

            print(f"\nPipeline initialized:")
            print(f"- Batch size: {self.batch_size}")
            print(f"- Epochs: {self.num_epochs}")
            print(f"- Checkpoints: {self.checkpoint_dir}")
            print(f"- Logs: {self.log_dir}")
            print(f"- Monitoring {len(self.label_handler.attack_types)} attack types")

        except Exception as e:
            print(f"Error setting up training components: {str(e)}")
            raise

    def train(self):
        """Execute the training loop with monitoring and stopping conditions.

        Training ends when all epochs are done, the batch loss plateaus, the
        epoch-average loss drops below ``1e-4``, or early stopping fires on
        validation accuracy. A ``KeyboardInterrupt`` saves an ``interrupt``
        checkpoint and returns normally; any other exception saves an
        ``error`` checkpoint and is re-raised.

        On an XLA device ``xm.mark_step()`` is called once training has ended
        without an error.
        """
        device = self.detection_system.device
        print(f"Training on device: {device}")

        # Batches are iterated straight from the generators as host arrays:
        # ``_train_batch`` calls ``.astype`` on individual samples and
        # ``preprocess_batch`` moves the tensors to the device itself, so no
        # device-side loader wrapper is created here.
        # NOTE(review): the previous unused ParallelLoader wrappers were
        # dropped; they appear to start consuming the generator in background
        # threads as soon as they are constructed — confirm against torch_xla.
        try:
            self._run_training_loop()
        except KeyboardInterrupt:
            print("\nTraining interrupted - saving checkpoint...")
            self._save_checkpoint('interrupt')
        except Exception as e:
            print(f"\nError during training: {str(e)}")
            traceback.print_exc()
            self._save_checkpoint('error')
            raise

        if 'xla' in str(device):
            xm.mark_step()

    def _run_training_loop(self):
        """Run the epoch loop; return early on convergence or early stopping."""
        # Convergence tracking (the plateau counter carries across epochs)
        plateau_counter = 0
        last_loss = float('inf')
        min_loss_change = 1e-5
        plateau_patience = 5
        max_iterations = 1000  # Maximum iterations per epoch

        for epoch in range(self.num_epochs):
            print(f"\nEpoch {epoch + 1}/{self.num_epochs}")
            print("=" * 50)

            self._epoch = epoch
            self.detection_system.model.train()
            epoch_metrics = defaultdict(list)

            n_batches = len(self.data_generator)

            with tqdm(total=n_batches, desc="Training") as pbar:
                for iteration_counter, (states, binary_labels, multi_labels) in enumerate(self.data_generator):
                    # Check iteration limit
                    if iteration_counter >= max_iterations:
                        print(f"\nReached maximum iterations ({max_iterations}) for epoch {epoch + 1}")
                        break

                    # Train batch (returns None until the replay buffer is full enough)
                    batch_metrics = self._train_batch(states, binary_labels, multi_labels)

                    if batch_metrics:
                        # Check for convergence
                        current_loss = batch_metrics.get('total_loss', float('inf'))
                        if abs(current_loss - last_loss) < min_loss_change:
                            plateau_counter += 1
                        else:
                            plateau_counter = 0
                        last_loss = current_loss

                        # Check plateau condition
                        if plateau_counter >= plateau_patience:
                            print(f"\nTraining converged (loss plateau reached)")
                            return

                        # Keep only numeric metrics for the running averages
                        for k, v in batch_metrics.items():
                            if isinstance(v, (int, float, np.number)):
                                epoch_metrics[k].append(float(v))

                        # Update progress bar
                        avg_metrics = {
                            k: np.mean(v) for k, v in epoch_metrics.items()
                        }
                        pbar.set_postfix(avg_metrics)
                        pbar.update(1)

                    # Memory check
                    if self.memory_monitor.check_memory():
                        print("\nHigh memory usage detected, breaking epoch")
                        break

            # Epoch completion checks
            avg_loss = self._epoch_average_loss(epoch_metrics)
            if avg_loss < 1e-4:  # Convergence threshold
                print(f"\nTraining converged (loss threshold reached)")
                return

            # Run validation and early stopping every fifth epoch
            if self.val_generator and (epoch + 1) % 5 == 0:
                val_metrics = self._evaluate(epoch)
                # Record improvements so saved checkpoints carry real best values.
                self._update_best_metrics(val_metrics, epoch)
                if self.early_stopping(val_metrics.get('accuracy', 0), self.detection_system.model):
                    print("\nEarly stopping triggered")
                    return

    @staticmethod
    def _epoch_average_loss(epoch_metrics: Dict[str, list]) -> float:
        """Return the mean ``'total_loss'`` of an epoch, or ``inf`` if none was recorded.

        Args:
            epoch_metrics: Mapping of metric name to the list of per-batch
                values collected during the epoch.
        """
        # ``.get`` avoids creating an empty entry when a defaultdict is passed.
        losses = epoch_metrics.get('total_loss')
        if not losses:
            return float('inf')
        return float(np.mean(losses))

    def _train_batch(self, states, binary_labels, multi_labels):
        """Train on a single batch.

        The batch is preprocessed, run through ``detection_system.detect``,
        and every sample is appended to the replay buffer. Once the buffer
        holds at least ``batch_size`` items, ``detection_system.train`` is
        invoked.

        Args:
            states: Indexable batch of samples; each sample must support
                ``.astype``.
            binary_labels: Indexable per-sample binary labels.
            multi_labels: Indexable per-sample multi-class labels.

        Returns:
            Dict of numeric metrics (values cast to ``float``), or ``None``
            when the replay buffer is still too small to train.

        Raises:
            Exception: Re-raises the original failure after printing the
                input shapes.
        """
        try:
            # Use preprocessed batch data
            states_tensor, binary_labels_tensor = self.preprocess_batch(
                states, binary_labels, multi_labels
            )

            # Forward pass
            detection_output = self.detection_system.detect(
                states_tensor, labels=binary_labels_tensor
            )

            # Store experience with safe type conversion
            for i in range(len(states)):
                try:
                    self.detection_system.replay_buffer.append((
                        states[i].astype(np.float32),  # Ensure float32
                        int(binary_labels[i]),         # Ensure int
                        int(multi_labels[i]),          # Ensure int
                        detection_output['is_poisoning'][i],
                        states[i].astype(np.float32),  # Ensure float32
                        True
                    ))
                except Exception as e:
                    # Best effort: a bad sample must not abort the batch.
                    print(f"Warning: Error storing experience {i}: {str(e)}")
                    continue

            # Train if enough samples
            if len(self.detection_system.replay_buffer) >= self.batch_size:
                batch_metrics = self.detection_system.train(self.batch_size)

                # Ensure all metrics are numeric
                numeric_metrics = {}
                for k, v in batch_metrics.items():
                    if isinstance(v, (int, float, np.number)):
                        numeric_metrics[k] = float(v)
                    else:
                        print(f"Warning: Non-numeric metric encountered: {k} = {v}")

                return numeric_metrics
            return None

        except Exception:
            # Inputs may lack ``.shape``; never let the diagnostics hide the
            # real error.
            print(f"Error in batch training:")
            print(f"States shape: {getattr(states, 'shape', 'invalid')}")
            print(f"Binary Labels shape: {getattr(binary_labels, 'shape', 'invalid')}")
            print(f"Multi Labels shape: {getattr(multi_labels, 'shape', 'invalid')}")
            raise

    def _update_best_metrics(self, val_metrics: Dict[str, float], epoch: int):
        """Raise each entry of ``best_metrics`` that the new results exceed.

        Non-numeric values (for example the nested ``'per_attack'`` dict) are
        ignored. Errors are reported as warnings and never propagated.

        Args:
            val_metrics: Metric name to value mapping from validation.
            epoch: Current epoch index (unused; kept for the call signature).
        """
        try:
            for metric_name, value in val_metrics.items():
                # np.number covers NumPy scalars that do not subclass float.
                if isinstance(value, (int, float, np.number)):
                    if metric_name not in self.best_metrics or value > self.best_metrics[metric_name]:
                        self.best_metrics[metric_name] = value
                        print(f"New best {metric_name}: {value:.4f}")
        except Exception as e:
            print(f"Warning: Error updating best metrics: {str(e)}")

    def _log_metrics(self, metrics: Dict[str, float]):
        """Log metrics to tensorboard and update history.

        Advances ``_global_step`` by one after logging. Errors are reported
        as warnings and never propagated.

        Args:
            metrics: Dictionary containing metric names and values
        """
        try:
            # Log to tensorboard
            for name, value in metrics.items():
                self.writer.add_scalar(
                    f'metrics/{name}',
                    value,
                    self._global_step
                )

                # Update history
                self.performance_history[name].append(value)

            # Update global step
            self._global_step += 1

        except Exception as e:
            print(f"Warning: Error logging metrics: {str(e)}")

    def _log_attack_metrics(self, attack_id: int, predictions: np.ndarray, labels: np.ndarray):
        """Compute and log binary precision/recall/F1/accuracy for one attack type.

        Args:
            attack_id: Key into ``attack_metrics`` and argument to
                ``label_handler.get_attack_info``.
            predictions: Class indices, or an array with more than one
                dimension which is reduced with ``argmax(axis=1)``.
            labels: Ground-truth labels aligned with ``predictions``.
        """
        try:
            attack_info = self.label_handler.get_attack_info(attack_id)
            # Convert predictions to class indices if needed
            if len(predictions.shape) > 1:
                pred_indices = predictions.argmax(axis=1)
            else:
                pred_indices = predictions

            metrics = {
                'precision': precision_score(labels, pred_indices, average='binary'),
                'recall': recall_score(labels, pred_indices, average='binary'),
                'f1': f1_score(labels, pred_indices, average='binary'),
                'accuracy': accuracy_score(labels, pred_indices)
            }

            # Log to tensorboard
            for name, value in metrics.items():
                self.writer.add_scalar(
                    f'attack_metrics/{attack_info["attack_name"]}/{name}',
                    value,
                    self._global_step
                )

                # Store in attack metrics
                self.attack_metrics[attack_id][name].append(value)

        except Exception as e:
            print(f"Warning: Error logging attack metrics: {str(e)}")

    def _log_batch_metrics(self, batch_metrics: Dict[str, float]):
        """Log batch-level training metrics.

        ``'total_loss'`` is additionally appended to ``train_losses``.
        """
        try:
            # Add batch metrics to history
            for name, value in batch_metrics.items():
                if name == 'total_loss':
                    self.train_losses.append(value)
                self.writer.add_scalar(f'batch/{name}', value, self._global_step)

        except Exception as e:
            print(f"Warning: Error logging batch metrics: {str(e)}")

    def _log_epoch_metrics(self, epoch: int, metrics: Dict[str, float]):
        """Average per-batch metric lists and log them for the epoch.

        Args:
            epoch: Epoch index used as the tensorboard step.
            metrics: Mapping of metric name to a collection of values; empty
                collections are skipped.

        Returns:
            Dict of ``'epoch_<name>'`` to mean value, or ``{}`` on error.
        """
        try:
            # Compute epoch averages
            epoch_metrics = {}
            for name, values in metrics.items():
                if values:  # Check if list is not empty
                    epoch_metrics[f'epoch_{name}'] = np.mean(values)

            # Log to tensorboard
            for name, value in epoch_metrics.items():
                self.writer.add_scalar(f'epoch/{name}', value, epoch)

            return epoch_metrics

        except Exception as e:
            print(f"Warning: Error logging epoch metrics: {str(e)}")
            return {}

    def _evaluate(self, epoch: int) -> Dict[str, float]:
        """Evaluate model performance with per-attack metrics.

        Only the first batch yielded by ``val_generator`` is evaluated. The
        model is switched to eval mode; ``train`` switches it back at the
        start of the next epoch.

        Args:
            epoch: Epoch index used as the tensorboard step.

        Returns:
            Combined overall and unknown-attack metrics plus a
            ``'per_attack'`` entry. May be partial or empty when evaluation
            fails, because errors are printed and not re-raised.
        """
        self.detection_system.model.eval()
        all_metrics = {}
        per_attack_metrics = defaultdict(list)

        try:
            # Get validation data
            eval_data, eval_binary_labels, eval_multi_labels = next(iter(self.val_generator))

            # Calculate overall metrics
            metrics = evaluate_model(
                self.detection_system,
                eval_data,
                eval_binary_labels
            )

            # Calculate per-attack metrics
            for attack_id in self.label_handler.attack_types:
                attack_mask = eval_multi_labels == attack_id
                if np.any(attack_mask):
                    attack_metrics = evaluate_model(
                        self.detection_system,
                        eval_data[attack_mask],
                        eval_binary_labels[attack_mask]
                    )
                    attack_name = self.label_handler.get_attack_info(attack_id)['attack_name']
                    per_attack_metrics[attack_name] = attack_metrics

            # Test unknown attack detection
            unknown_metrics = test_unknown_attack_detection(
                self.detection_system,
                eval_data
            )

            # Combine all metrics
            all_metrics.update(metrics)
            all_metrics.update(unknown_metrics)
            all_metrics['per_attack'] = per_attack_metrics

            # Log per-attack metrics to tensorboard
            self._log_per_attack_metrics(per_attack_metrics, epoch)

        except Exception as e:
            print(f"Error during evaluation: {e}")

        return all_metrics

    def _log_per_attack_metrics(self, per_attack_metrics: Dict, epoch: int):
        """Log per-attack metrics to tensorboard.

        Args:
            per_attack_metrics: Mapping of attack name to a metric dict.
            epoch: Epoch index used as the tensorboard step.
        """
        for attack_name, metrics in per_attack_metrics.items():
            for metric_name, value in metrics.items():
                self.writer.add_scalar(
                    f'per_attack/{attack_name}/{metric_name}',
                    value,
                    epoch
                )

    def _save_checkpoint(self, identifier: str):
        """Save model/optimizer state and tracking info.

        The file is written to
        ``<checkpoint_dir>/checkpoint_<identifier>.pt``.
        """
        path = os.path.join(
            self.checkpoint_dir,
            f'checkpoint_{identifier}.pt'
        )

        checkpoint = {
            'epoch': self._epoch,
            'model_state_dict': self.detection_system.model.state_dict(),
            'optimizer_state_dict': self.detection_system.optimizer.state_dict(),
            'metrics': self.best_metrics,
            'global_step': self._global_step,
            'label_mapping': self.label_handler.label_mapping,  # Save label information
            'attack_types': self.label_handler.attack_types
        }

        torch.save(checkpoint, path)
        print(f"\nCheckpoint saved: {path}")

    def _handle_oom_error(self):
        """Free cached CUDA memory and halve ``batch_size`` (never below 32)."""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        self.batch_size = max(32, self.batch_size // 2)
        print(f"\nReducing batch size to: {self.batch_size}")

    def _check_memory(self):
        """Trigger OOM handling when allocated CUDA memory exceeds 90% of device 0."""
        if torch.cuda.is_available():
            memory_used = torch.cuda.memory_allocated()/1e9
            if memory_used > 0.9 * torch.cuda.get_device_properties(0).total_memory/1e9:
                print(f"\nHigh GPU memory usage ({memory_used:.2f}GB)")
                self._handle_oom_error()

    def _log_training_progress(self, epoch: int, batch_idx: int, metrics: Dict):
        """Log numeric training metrics and set ``_global_step``.

        The step is ``epoch * len(data_generator) + batch_idx``.
        """
        try:
            # Log to tensorboard
            step = epoch * len(self.data_generator) + batch_idx
            for name, value in metrics.items():
                if isinstance(value, (int, float, np.number)):
                    self.writer.add_scalar(f'training/{name}', value, step)

            # Update global step
            self._global_step = step

        except Exception as e:
            print(f"Warning: Error logging progress: {str(e)}")

    def _handle_epoch_completion(self, epoch: int, epoch_metrics: Dict):
        """Summarize an epoch, log averages and optionally validate.

        Validation runs every fifth epoch when a validation generator is set.
        Errors are printed with a traceback and not propagated.

        Args:
            epoch: Zero-based epoch index.
            epoch_metrics: Mapping of metric name to per-batch values; only
                non-empty, all-numeric sequences are averaged.
        """
        try:
            # Calculate average metrics
            avg_metrics = {}
            for k, v in epoch_metrics.items():
                if isinstance(v, (list, np.ndarray)) and len(v) > 0:
                    if all(isinstance(x, (int, float, np.number)) for x in v):
                        avg_metrics[k] = float(np.mean(v))

            # Print epoch summary
            print(f"\nEpoch {epoch + 1} Complete:")
            print("-" * 50)
            print("Training Metrics:")
            for name, value in avg_metrics.items():
                print(f"- {name}: {value:.4f}")

            # Log to tensorboard
            for name, value in avg_metrics.items():
                self.writer.add_scalar(f'epoch/{name}', value, epoch)

            # Validation if available
            if self.val_generator and (epoch + 1) % 5 == 0:
                print("\nRunning Validation...")
                val_metrics = self._evaluate(epoch)
                self._update_best_metrics(val_metrics, epoch)

            # Memory management
            if self.memory_monitor.check_memory():
                self._handle_oom_error()

            print("\nCurrent Best Metrics:")
            for metric, value in self.best_metrics.items():
                if isinstance(value, (int, float, np.number)):
                    print(f"- Best {metric}: {value:.4f}")

        except Exception as e:
            print(f"Warning: Error in epoch completion handling: {str(e)}")
            traceback.print_exc()

    def _check_convergence(self, current_loss: float, last_loss: float,
                          plateau_counter: int, min_loss_change: float) -> Tuple[int, bool]:
        """Update the plateau counter and report whether training has converged.

        Returns:
            ``(new_plateau_counter, converged)`` where ``converged`` is true
            once the counter reaches ``config.patience``.
        """
        if abs(current_loss - last_loss) < min_loss_change:
            plateau_counter += 1
        else:
            plateau_counter = 0

        return plateau_counter, plateau_counter >= self.config.patience

    def _should_stop_training(self, iteration_counter: int, max_iterations: int,
                            avg_loss: float, plateau_counter: int) -> bool:
        """Return True when the iteration limit, loss floor or plateau limit is hit."""
        if iteration_counter >= max_iterations:
            print(f"\nReached maximum iterations ({max_iterations})")
            return True

        if avg_loss < 1e-4:
            print(f"\nReached minimum loss threshold")
            return True

        if plateau_counter >= self.config.patience:
            print(f"\nLoss plateau reached")
            return True

        return False


## Dataset Processor

In [None]:
# Initialize the module-level configuration object used by the code below.
# NOTE(review): the original remark says ModelConfig detects and sets up a TPU
# on construction; that logic is not visible in this cell — confirm.
config = ModelConfig()  # Will automatically detect and setup TPU

class DatasetProcessor:
    """Loads one dataset at a time and trains a detection system on it."""

    def __init__(self, config: ModelConfig):
        """Keep the shared configuration and start with empty result registries."""
        self.all_metrics = {}
        self.all_models = {}
        self.config = config

    def _make_generator(self, dataset_type: str, arrays, shuffle: bool):
        """Build a batch generator over a single ``(X, binary, multi)`` triple."""
        return MultiDatasetGenerator(
            {dataset_type: arrays},
            batch_size=self.config.batch_size,
            shuffle=shuffle,
            num_workers=self.config.num_workers
        )

    def process_single_dataset(self, dataset_type: str, data_path: str):
        """Load, split, and train on one dataset.

        The first 80% of the rows (in stored order) are used for training and
        the remainder for validation.

        Args:
            dataset_type: Dataset key passed to the loader and used in the
                name of the saved side-information file.
            data_path: Location handed to ``load_and_process_dataset``.

        Returns:
            ``(detection_system, metrics_tracker)`` on success, or
            ``(None, None)`` when any step raises (the error and traceback
            are printed).
        """
        try:
            # Load dataset using EnhancedDatasetLoader
            loader = EnhancedDatasetLoader(dataset_type, self.config)
            loaded = loader.load_and_process_dataset(data_path)
            X, binary_labels, multi_labels, additional_info = loaded
            print(f"\nDataset loaded - Shape: {X.shape}")

            # Metrics tracker is returned to the caller alongside the model
            metrics_tracker = PoisoningDetectionMetrics(loader.label_handler)
            print("Created metrics tracker")

            # Ordered 80/20 split into training and validation portions
            split_idx = int(0.8 * len(X))
            train_arrays = (
                X[:split_idx],
                binary_labels[:split_idx],
                multi_labels[:split_idx],
            )
            val_arrays = (
                X[split_idx:],
                binary_labels[split_idx:],
                multi_labels[split_idx:],
            )

            train_generator = self._make_generator(dataset_type, train_arrays, True)
            print("Created data generator")

            val_generator = self._make_generator(dataset_type, val_arrays, False)
            print("Created validation generator")

            # Detection system sized to the feature dimension
            detection_system = EnhancedPoisoningDetectionSystem(
                input_dim=X.shape[1],
                config=self.config,
                label_handler=loader.label_handler
            )
            print("Initialized detection system")

            # Training pipeline
            pipeline = ComprehensiveTrainingPipeline(
                detection_system=detection_system,
                data_generator=train_generator,
                val_generator=val_generator,
                config=self.config,
                label_handler=loader.label_handler
            )

            print("\nStarting training...")
            pipeline.train()

            # Persist the loader's side information next to the checkpoints
            self._save_additional_info(dataset_type, additional_info)

        except Exception as e:
            print(f"Error processing dataset: {str(e)}")
            traceback.print_exc()
            return None, None

        return detection_system, metrics_tracker

    def _save_additional_info(self, dataset_type: str, additional_info: Dict):
        """Write ``additional_info`` to ``<checkpoint_dir>/<dataset_type>_additional_info.pt``."""
        file_name = f'{dataset_type}_additional_info.pt'
        save_path = os.path.join(self.config.checkpoint_dir, file_name)
        torch.save(additional_info, save_path)
        print(f"\nSaved additional processing info to {save_path}")


## Main Execution 1

In [None]:
def tpu_training_wrapper(func):
    """Decorator that picks single- or multi-core TPU execution for ``func``.

    When ``xm.xrt_world_size()`` reports more than one core the function is
    launched through ``xmp.spawn`` with 8 processes; otherwise it is called
    directly.

    Args:
        func: Callable to run.

    Returns:
        A wrapper that keeps ``func``'s name and docstring. In the
        single-core case it returns ``func``'s result; in the multi-core
        case it returns ``None``.
    """
    import functools  # local import so the cell does not depend on a top-level one

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if xm.xrt_world_size() > 1:
            # Multi-core TPU execution.
            # NOTE(review): keyword arguments are not forwarded on this path,
            # and the spawn helper presumably passes a process index as the
            # first positional argument — confirm that ``func`` accepts it.
            xmp.spawn(func, args=args, nprocs=8)
            return None
        # Single core execution
        return func(*args, **kwargs)
    return wrapper

@tpu_training_wrapper
def main_single_dataset():
    """Train on each configured dataset in turn, with optional TPU handling.

    A failure on one dataset is printed and the loop moves on to the next.

    Returns:
        Dict mapping dataset key to ``{'model': ..., 'metrics': ...}`` for
        every dataset that trained successfully, or ``None`` when setup or
        the surrounding loop itself fails.
    """
    try:
        # Setup with TPU awareness
        device = setup_tpu()
        config = ModelConfig()

        if not setup_system():
            raise RuntimeError("System setup failed")

        # The device does not change during the run, so classify it once.
        on_tpu = 'xla' in str(device)

        # Dataset configurations
        datasets = {
            'cic': {'path': '/content/CIC_IoT_M3.csv', 'description': 'CIC-IDS Dataset'},
            'ton': {'path': '/content/UNSW_TON_IoT.csv', 'description': 'TON-IoT Dataset'},
            'cse': {'path': '/content/CSE-CIC_2018.csv', 'description': 'CSE-CIC Dataset'}
        }

        # Process one dataset at a time
        results = {}
        for dataset_type, info in datasets.items():
            print(f"\n{'='*50}")
            print(f"Processing {dataset_type.upper()} Dataset")
            print(f"{'='*50}")

            try:
                trainer = SingleDatasetTrainer(config, dataset_type)

                if not on_tpu:
                    model, metrics = trainer.train_on_dataset(info['path'])
                else:
                    print("Using TPU for training")
                    # NOTE(review): confirm that xm.master_print_every_n_sec
                    # exists and is usable as a context manager.
                    with xm.master_print_every_n_sec():
                        model, metrics = trainer.train_on_dataset(info['path'])

                results[dataset_type] = {'model': model, 'metrics': metrics}

                print(f"\nCompleted processing {dataset_type.upper()}")
                print("Cleaning up memory...")
                gc.collect()

                # TPU-specific cleanup
                if on_tpu:
                    xm.mark_step()

                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

            except Exception as e:
                # Best effort: report and continue with the next dataset.
                print(f"Error processing {dataset_type} dataset: {str(e)}")
                continue

        return results

    except Exception as e:
        print(f"Error in main execution: {str(e)}")
        traceback.print_exc()
        return None

if __name__ == "__main__":
    # Metrics module is only needed when run as a script.
    import torch_xla.debug.metrics as met

    try:
        main_single_dataset()

        # Report TPU metrics when at least one XLA device is in the world.
        world_size = xm.xrt_world_size()
        if world_size > 0:
            print("\nTPU Metrics:")
            print(met.metrics_report())
    except Exception as e:
        print(f"Error in TPU execution: {str(e)}")
        traceback.print_exc()



---

