**1. Environment Setup:** Install necessary dependencies. We use pycryptodome for the secure Keccak-512 hashing algorithm and GPUtil for monitoring GPU resources.

In [None]:
# Install necessary libraries for Hashing and GPU monitoring
!pip install GPUtil
!pip install pycryptodome

Collecting GPUtil
  Downloading GPUtil-1.4.0.tar.gz (5.5 kB)
  Preparing metadata (setup.py) ... done
Building wheels for collected packages: GPUtil
  Building wheel for GPUtil (setup.py) ... done
  Created wheel for GPUtil: filename=GPUtil-1.4.0-py3-none-any.whl size=7392 sha256=2b8e44bfa064c1c37a08112b65025259136ba713d4ee7f8157e503da15c0e357
  Stored in directory: /root/.cache/pip/wheels/92/a8/b7/d8a067c31a74de9ca252bbe53dea5f896faabd25d55f541037
Successfully built GPUtil
Installing collected packages: GPUtil
Successfully installed GPUtil-1.4.0
Collecting pycryptodome
  Downloading pycryptodome-3.23.0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.4 kB)
Downloading pycryptodome-3.23.0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.3/2.3 MB[0m [31m84.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pycryptodome
Successfully installed pycryptodome-3.23.0

**2. Imports & Configuration:** Import standard libraries for deep learning (TensorFlow/Keras), data manipulation, and cryptography.

In [None]:
import os
import time
import random
import pickle
import psutil
import subprocess
import numpy as np
import pandas as pd
import tensorflow as tf
import seaborn as sns
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import (
    classification_report, accuracy_score, precision_score,
    recall_score, f1_score, confusion_matrix
)
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, BatchNormalization, Input
from Crypto.Hash import keccak

# ==========================================
# 1. CONFIGURATION & HYPERPARAMETERS
# ==========================================
# Module-level constants read by the helper functions and the simulation
# loop defined further down in this notebook.

# Federated Learning Settings
# TOTAL_CLIENTS: number of simulated clients; ROUNDS: number of federated
# rounds; EPOCHS / BATCH_SIZE: passed to each client's local `fit` call.
TOTAL_CLIENTS = 10
ROUNDS = 10
EPOCHS = 5
BATCH_SIZE = 32

# Attack Configuration
# Note: In a 10-client simulation, we inject 1 of each attacker type.
DATA_POISONING_CLIENTS = 1   # Clients whose labels are cyclically shifted before training (expected: high loss)
MODEL_POISONING_CLIENTS = 1  # Clients that add Gaussian noise to their trained weights (expected: high weight divergence)
TAMPERING_CLIENTS = 1        # Honest clients whose updates are modified after hashing (expected: integrity check fails)

# Detection Thresholds
# Statistical multipliers for outlier detection: a client is flagged when its
# loss / mean weight distance exceeds the round median times the multiplier
# (e.g., 2.0x the median).
LOSS_MULTIPLIER = 2.0
WEIGHT_MULTIPLIER = 2.0

# Dataset Configuration
# UPDATE THIS PATH to your local or Drive file location
DATASET_PATH = '/content/drive/MyDrive/IDS-IOT2024/Process_1 IDS-IoT-2024.csv'
# Alternative when the CSV sits next to the notebook:
# DATASET_PATH = 'Process_1 IDS-IoT-2024.csv'

# Reproducibility
# Seed Python's `random`, NumPy and TensorFlow so role assignment, noise
# injection and weight initialisation draw from seeded generators.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

print("Configuration complete.")

Configuration complete.


**3. Utility Functions (Metrics & Hashing):** Define helper functions to monitor system resources (process RAM and GPU memory), compute model size, and perform the core cryptographic operations: `hash_weights` (Keccak-512) and `verify_hash`.

In [None]:

def get_sys_memory():
    """Report the resident set size (RSS) of the current process in GB."""
    bytes_per_gb = 1024 ** 3
    rss_bytes = psutil.Process().memory_info().rss
    return rss_bytes / bytes_per_gb

def get_gpu_memory():
    """
    Returns GPU memory usage in GB via nvidia-smi.

    `nvidia-smi` prints one value per GPU (one per line); the values are
    summed so that multi-GPU hosts report their total usage.

    Returns: float. 0.0 when `nvidia-smi` is missing, exits with an error,
    or prints something that cannot be parsed as a number.
    """
    try:
        result = subprocess.check_output(
            ['nvidia-smi', '--query-gpu=memory.used', '--format=csv,nounits,noheader']
        )
        # One token per GPU; the reported value is divided by 1024 to convert to GB.
        per_gpu = [float(value) for value in result.decode('utf-8').split()]
    except (OSError, subprocess.SubprocessError, ValueError):
        # Best-effort metric: no GPU tooling or unparsable output means "no GPU usage".
        return 0.0
    return sum(per_gpu) / 1024

def hash_weights(weights):
    """
    Produce a Keccak-512 fingerprint of the model weights.

    The weights object is serialized with `pickle.dumps` and the resulting
    bytes are hashed.
    Returns: Hex digest string.
    """
    # NOTE(review): the digest depends on the exact pickle byte stream, so
    # sender and verifier presumably must serialize identically -- confirm.
    payload = pickle.dumps(weights)
    return keccak.new(data=payload, digest_bits=512).hexdigest()

def verify_hash(weights, received_hash):
    """
    Check the integrity of 'weights' against the digest that came with them.

    The hash of 'weights' is recomputed with `hash_weights` and compared to
    'received_hash'.
    Returns: Boolean (True when both digests are equal).
    """
    recomputed_hash = hash_weights(weights)
    digests_match = recomputed_hash == received_hash
    return digests_match

def get_model_size_kb(model):
    """Estimate the model's parameter footprint in Kilobytes (KB)."""
    # Every parameter is assumed to be stored as float32 (4 bytes).
    bytes_per_param = 4
    param_count = sum(np.prod(tensor.shape) for tensor in model.get_weights())
    total_bytes = param_count * bytes_per_param
    return total_bytes / 1024

def get_communication_size_mb(model):
    """Express the model size in Megabytes (MB); used as the per-update communication cost."""
    size_kb = get_model_size_kb(model)
    return size_kb / 1024

**4. Data Loading & Preprocessing:** Load the IDS-IoT-2024 dataset, encode the labels, and split the data into a training set (distributed to clients) and a test set (used for server-side evaluation).

In [None]:

def load_and_preprocess_data():
    """
    Loads the CSV at DATASET_PATH and prepares train/test arrays.

    Features are all columns except the last one; the label is the
    'Attack_Category_x' column, encoded to integers with LabelEncoder.
    The label column is removed from the features even when it is not
    the last column, so the target can never leak into X.

    Returns: (X_train, X_test, y_train, y_test, num_classes, encoder),
    or None when the dataset file does not exist.
    Raises: KeyError if the 'Attack_Category_x' column is missing.
    """
    label_column = 'Attack_Category_x'

    print(f"[*] Loading dataset: {DATASET_PATH}")
    # Keep the try body minimal: only the read can raise FileNotFoundError.
    try:
        df = pd.read_csv(DATASET_PATH)
    except FileNotFoundError:
        print(f"[!] Error: File not found at {DATASET_PATH}")
        return None
    print(f"    - Shape: {df.shape}")

    # Positional selection (drop the last column) is kept from the original
    # layout; the explicit drop guards against the label sitting elsewhere.
    features = df.iloc[:, :-1]
    if label_column in features.columns:
        features = features.drop(columns=[label_column])
    X = features.values
    y = df[label_column].values

    # Encode Labels
    encoder = LabelEncoder()
    y_encoded = encoder.fit_transform(y)
    num_classes = len(np.unique(y_encoded))

    # Split Data (80% Global Train (distributed to clients), 20% Global Test)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y_encoded, test_size=0.2, random_state=SEED
    )
    return X_train, X_test, y_train, y_test, num_classes, encoder

**5. Model Architecture & Client Setup:** Define the lightweight Feed-Forward Neural Network (FFNN) architecture. We also create the `distribute_data_and_assign_roles` function, which gives each client a subset of the training data and assigns specific attack roles (Data Poisoning, Model Poisoning, Tampering).

In [None]:
def create_lightweight_model(input_dim, num_classes):
    """
    Builds and compiles the lightweight feed-forward classifier.

    Layout: Input -> [Dense(128) -> BN -> Dropout(0.2)]
                  -> [Dense(64)  -> BN -> Dropout(0.2)]
                  -> Dense(num_classes, softmax)
    Compiled with the 'AdamW' optimizer, sparse categorical cross-entropy
    loss and an accuracy metric.
    """
    hidden_units = (128, 64)

    stack = [Input(shape=(input_dim,))]
    for units in hidden_units:
        # Each hidden stage: ReLU dense layer, batch norm, then 20% dropout.
        stack.append(Dense(units, activation='relu'))
        stack.append(BatchNormalization())
        stack.append(Dropout(0.2))
    stack.append(Dense(num_classes, activation='softmax'))

    network = Sequential(stack)
    network.compile(
        optimizer='AdamW',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
    )
    return network

def distribute_data_and_assign_roles(X, y):
    """
    Gives every client a data subset and picks which clients play attack roles.

    Returns: (client_data, data_poisoners, model_poisoners, tamper_victims)
    where client_data is a list of (X_c, y_c) tuples indexed by client id and
    the other three are sets of client ids.
    """
    client_ids = list(range(TOTAL_CLIENTS))

    # Role assignment: draw all malicious clients at once, then split the
    # draw into data poisoners (first part) and model poisoners (the rest).
    attacker_count = DATA_POISONING_CLIENTS + MODEL_POISONING_CLIENTS
    attackers = random.sample(client_ids, attacker_count)
    data_poisoners = set(attackers[:DATA_POISONING_CLIENTS])
    model_poisoners = set(attackers[DATA_POISONING_CLIENTS:])

    # Tampering victims are drawn only from clients that are not attackers.
    honest_ids = [cid for cid in client_ids if cid not in attackers]
    tamper_victims = set(random.sample(honest_ids, TAMPERING_CLIENTS))

    print("[*] Role Assignment:")
    print(f"    - Data Poisoners: {data_poisoners}")
    print(f"    - Model Poisoners: {model_poisoners}")
    print(f"    - Tampering Victims: {tamper_victims}")

    def _client_subset(cid):
        # Each client keeps the 10% "train" side of a split seeded by its id.
        features, _, labels, _ = train_test_split(X, y, test_size=0.9, random_state=cid)
        if cid in data_poisoners:
            # Data poisoning: rotate the label vector by one position.
            labels = np.roll(labels, shift=1, axis=0)
        return features, labels

    client_data = [_client_subset(cid) for cid in client_ids]

    return client_data, data_poisoners, model_poisoners, tamper_victims

**6. Server Detection Logic:** Implement the statistical detection mechanism (The Poisoning Gate). This function analyzes client losses and weight updates to identify outliers, flagging them as potential Data or Model poisoning attacks.

In [None]:
def detect_poisoning(client_losses, weight_diffs, round_num):
    """
    Median-based outlier detection over one round of client updates.

    client_losses: per-client loss values.
    weight_diffs: per-client lists of weight distances (averaged per client).
    round_num: zero-based round index; round 0 is a warm-up with no detection.

    Returns: list of (position, reason) tuples, where position indexes into
    the input lists and reason is 'Data Poisoning' or 'Model Poisoning'.
    A client gets at most one reason; the loss check takes precedence.
    """
    # Medians serve as the reference point for both checks.
    loss_median = np.median(client_losses) if client_losses else 0
    diff_median = np.median([np.mean(d) for d in weight_diffs]) if weight_diffs else 0

    # Warm-up: nothing is flagged in the first round.
    if round_num == 0:
        return []

    flagged = []
    for position, (client_loss, layer_diffs) in enumerate(zip(client_losses, weight_diffs)):
        mean_diff = np.mean(layer_diffs)

        # 1. Loss far above the median (and above the 0.5 floor).
        if client_loss > 0.5 and client_loss > (loss_median * LOSS_MULTIPLIER):
            flagged.append((position, 'Data Poisoning'))
            continue

        # 2. Mean weight distance far above the median (and above the 0.1 floor).
        if mean_diff > 0.1 and mean_diff > (diff_median * WEIGHT_MULTIPLIER):
            flagged.append((position, 'Model Poisoning'))

    return flagged

**7. Main Training Simulation:** The core simulation loop. It orchestrates the entire Integrity-Driven Federated Learning process:


1.   Client Training: Local training, hash computation, and attack simulation (MITM/Tampering).
2.   Server Verification: Integrity check (Gate 1) -> Poisoning detection (Gate 2) -> Aggregation.
3.   Evaluation: Compute global accuracy, throughput, overheads, and resource usage.


In [None]:
# Main federated-learning simulation.
# Per round: every non-banned client trains locally from the current global
# weights, hashes its weights and "uploads" them (some are noised or tampered);
# the server verifies hashes, runs statistical poisoning detection, averages
# the surviving updates (FedAvg) and evaluates the global model on X_test.

# 1. Init Data
data_pack = load_and_preprocess_data()
if not data_pack:
    print("Simulation stopped due to missing data.")
else:
    X_train, X_test, y_train, y_test, n_classes, encoder = data_pack

    # 2. Init Clients & Roles
    client_data, _, model_poisoners, tamper_victims = distribute_data_and_assign_roles(X_train, y_train)

    # 3. Init Models
    # Global Model (Server)
    global_model = create_lightweight_model(X_train.shape[1], n_classes)
    # Reusable Client Model: one worker model is reused for all clients
    # instead of building a new model per client per round.
    # NOTE(review): only the weights are reset per client; optimizer state of
    # this shared worker presumably carries over between clients -- confirm.
    client_worker = create_lightweight_model(X_train.shape[1], n_classes)

    model_size_mb = get_communication_size_mb(global_model)
    model_size_kb = get_model_size_kb(global_model)

    # 4. Metric Logs (one entry appended per round)
    history = {
        'acc': [], 'f1': [], 'prec': [], 'rec': [],
        'throughput_samples': [], 'throughput_mbs': [],
        'comp_time': [], 'comm_overhead': [],
        'hash_comp': [], 'hash_verif': [],
        'cpu': [], 'gpu': []
    }

    # Clients banned by the poisoning detector; they skip all later rounds.
    excluded_clients = set()

    print("\n" + "="*50)
    print("STARTING FEDERATED LEARNING SIMULATION")
    print(f"Clients: {TOTAL_CLIENTS} | Rounds: {ROUNDS}")
    print("="*50)

    for r in range(ROUNDS):
        round_start = time.time()
        print(f"\n--- Round {r+1}/{ROUNDS} ---")

        # Snapshot the global weights once per round: they do not change until
        # aggregation, so every client download and every server-side distance
        # computation can share this copy instead of re-fetching it.
        global_weights = global_model.get_weights()

        # Buffers for server-side processing
        updates_buffer = []

        # --- CLIENT PHASE ---
        for cid, (X_c, y_c) in enumerate(client_data):
            if cid in excluded_clients:
                continue

            # 1. Download Global Weights
            client_worker.set_weights(global_weights)

            # 2. Local Training (Measure Computation Overhead)
            t_start_train = time.time()
            hist = client_worker.fit(X_c, y_c, epochs=EPOCHS, batch_size=BATCH_SIZE, verbose=0)
            t_train_dur = time.time() - t_start_train

            # Get Parameters (loss reported to the server is the last epoch's training loss)
            weights = client_worker.get_weights()
            loss = hist.history['loss'][-1]

            # Apply Model Poisoning (Noise Injection) if applicable.
            # Done BEFORE hashing, so the poisoned update still passes the integrity gate.
            if cid in model_poisoners:
                weights = [w + np.random.normal(0, 0.5, w.shape) for w in weights]

            # 3. Hash Computation (Measure Time)
            t_start_hash = time.time()
            c_hash = hash_weights(weights)
            t_hash_dur = time.time() - t_start_hash

            # 4. Transmission Simulation (MITM Tampering)
            tx_weights = weights
            if cid in tamper_victims:
                # Modify weights in transit, AFTER hashing (Hash will NOT match)
                tx_weights = [w + np.random.normal(0, 0.01, w.shape) for w in weights]

            updates_buffer.append({
                'id': cid,
                'weights': tx_weights,
                'hash': c_hash,
                'loss': loss,
                'comp_overhead': t_train_dur, # Computation overhead is recorded as local training time
                'hash_overhead': t_hash_dur
            })

            # Simple progress indicator
            print(".", end="", flush=True)

        print(" [Upload Complete]")

        # --- SERVER PHASE ---

        round_losses = []
        round_diffs = []
        temp_map = [] # Position i here corresponds to position i in round_losses/round_diffs

        hash_verif_times = []

        # Gate 1: Integrity Verification
        for update in updates_buffer:
            t_v_start = time.time()
            is_valid = verify_hash(update['weights'], update['hash'])
            hash_verif_times.append(time.time() - t_v_start)

            if not is_valid:
                # Updates failing the integrity check are dropped silently
                # (the client is not banned and will be checked again next round).
                continue

            # Prepare for Poisoning Check
            # Per-layer L2 norm distance from the global model
            diff = [np.linalg.norm(client_layer - global_layer)
                    for client_layer, global_layer in zip(update['weights'], global_weights)]

            round_losses.append(update['loss'])
            round_diffs.append(diff)
            temp_map.append(update)

        # Gate 2: Poisoning Detection
        detected_attacks = detect_poisoning(round_losses, round_diffs, r)
        poison_indices = {local_idx for local_idx, _ in detected_attacks}

        # Log and Exclude
        for local_idx, reason in detected_attacks:
            real_id = temp_map[local_idx]['id']
            if real_id not in excluded_clients:
                print(f"    [!] DETECTED Client {real_id}: {reason}. Banning.")
                excluded_clients.add(real_id)

        # Gate 3: Aggregation (only updates that passed both gates)
        valid_updates = [update['weights']
                         for i, update in enumerate(temp_map)
                         if i not in poison_indices]

        if valid_updates:
            # FedAvg: unweighted element-wise mean of each layer across clients
            new_weights = [np.mean(layer_stack, axis=0)
                           for layer_stack in zip(*valid_updates)]
            global_model.set_weights(new_weights)
            print(f"    > Aggregated {len(valid_updates)} updates.")
        else:
            print("    > No valid updates to aggregate.")

        # --- EVALUATION PHASE ---

        # 1. Classification Metrics
        y_pred = np.argmax(global_model.predict(X_test, verbose=0), axis=1)

        acc = accuracy_score(y_test, y_pred)
        prec = precision_score(y_test, y_pred, average='weighted', zero_division=0)
        rec = recall_score(y_test, y_pred, average='weighted', zero_division=0)
        f1 = f1_score(y_test, y_pred, average='weighted', zero_division=0)

        # 2. Performance Metrics
        r_time = time.time() - round_start
        throughput_samples = len(X_train) / r_time # len(X_train) divided by the round's wall-clock time

        # Communication Overhead: Size of model * number of clients who sent updates
        comm_overhead_mb = model_size_mb * len(updates_buffer)

        # Throughput in Mbs (Megabits per second)
        throughput_mbs = (comm_overhead_mb * 8) / r_time

        # Per-client averages. Guarded so a round in which every client has
        # been banned records 0.0 instead of NaN (np.mean of an empty list).
        if updates_buffer:
            # Avg Computation Overhead (Training time per client)
            avg_comp_time = np.mean([u['comp_overhead'] for u in updates_buffer]) * 1000 # ms
            # Hash Times (seconds)
            avg_hash_comp_sec = np.mean([u['hash_overhead'] for u in updates_buffer])
        else:
            avg_comp_time = 0.0
            avg_hash_comp_sec = 0.0
        avg_hash_verif_sec = np.mean(hash_verif_times) if hash_verif_times else 0

        # Resources (process RAM in GB is logged under the 'cpu' key)
        cpu_usage = get_sys_memory()
        gpu_usage = get_gpu_memory()

        # Store
        history['acc'].append(acc)
        history['f1'].append(f1)
        history['prec'].append(prec)
        history['rec'].append(rec)
        history['throughput_samples'].append(throughput_samples)
        history['throughput_mbs'].append(throughput_mbs)
        history['comp_time'].append(avg_comp_time)
        history['comm_overhead'].append(comm_overhead_mb)
        history['hash_comp'].append(avg_hash_comp_sec)
        history['hash_verif'].append(avg_hash_verif_sec)
        history['cpu'].append(cpu_usage)
        history['gpu'].append(gpu_usage)

        # Print Round Summary
        print(f"    Accuracy: {acc:.4f} | F1: {f1:.4f}")
        print(f"    Throughput: {throughput_samples:.2f} samples/sec | {throughput_mbs:.4f} Mbps")
        print(f"    Avg Comp Overhead: {avg_comp_time:.2f} ms")
        print(f"    Avg Hash Gen Time: {avg_hash_comp_sec:.6f} s | Verif Time: {avg_hash_verif_sec:.6f} s")
        print(f"    CPU Usage: {cpu_usage:.2f} GB | GPU Usage: {gpu_usage:.4f} GB")