**1. Imports and Setup**

In [3]:
import pandas as pd
import numpy as np
import time
import pickle
import os
import random
import lightgbm as lgb
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

# Seed both global RNGs (Python's `random` and NumPy's legacy generator)
# with one shared value so a fresh Restart & Run All repeats the same draws.
SEED = 42
for seed_rng in (random.seed, np.random.seed):
    seed_rng(SEED)

print("Libraries loaded. Setup complete.")

Libraries loaded. Setup complete.


**2. Security & Aggregation Classes:** This cell implements the core components described in the paper: Schnorr zero-knowledge-proof (ZKP) client authentication, Laplace-mechanism differential privacy (DP), and median-based robust aggregation.

In [4]:
# --- 1. Zero-Knowledge Proof (Schnorr Protocol) ---
# Client Authentication
class SchnorrProtocol:
    """Simulation-grade Schnorr identification over the integers modulo ``p``.

    One instance plays the prover: it owns a private exponent ``x`` and the
    public key ``y = g**x mod p``. A proof is the triple (commitment ``R``,
    challenge ``c``, response ``s``) and is accepted by :meth:`verify` when
    ``g**s == R * y**c (mod p)``.

    NOTE(security): secrets are drawn from NumPy's global ``np.random``
    generator, which is seedable and not cryptographically secure. That keeps
    the experiment reproducible but is unsuitable for real authentication.

    Args:
        p: Modulus of the group (expected to be a prime).
        g: Base used for all exponentiations.
    """

    def __init__(self, p, g):
        self.p = p
        self.g = g
        self.private_key = self._random_exponent()
        self.public_key = pow(g, self.private_key, p)

    def _random_exponent(self):
        """Draw a secret exponent from ``[1, p - 2]`` as a built-in ``int``."""
        # Responses are reduced modulo p - 1, so an exponent equal to p - 1
        # acts like 0 and, for a prime p, yields g**(p-1) == 1: a trivial
        # public key / commitment. randint's upper bound is exclusive, so
        # passing p - 1 keeps the draw inside [1, p - 2]. For p <= 2 that
        # range would be empty, so the original bound is kept there.
        upper = self.p - 1 if self.p > 2 else self.p
        return int(np.random.randint(1, upper))

    def generate_commitment(self):
        """Pick a fresh nonce ``r`` and return the commitment ``R = g**r mod p``.

        The nonce is stored on the instance for the following
        :meth:`compute_response` call.
        """
        self.r = self._random_exponent()
        self.R = pow(self.g, self.r, self.p)
        return self.R

    def compute_response(self, challenge):
        """Return ``s = (r + challenge * x) mod (p - 1)`` for the current nonce.

        :meth:`generate_commitment` must have been called first; otherwise
        ``self.r`` does not exist and an ``AttributeError`` is raised.
        """
        # int() keeps the product in arbitrary-precision Python arithmetic
        # even when the challenge arrives as a fixed-width NumPy integer.
        self.s = (self.r + int(challenge) * self.private_key) % (self.p - 1)
        return self.s

    def verify(self, R, s, public_key, challenge):
        """Check a proof: ``True`` iff ``g**s == R * public_key**challenge (mod p)``.

        Args:
            R: Commitment sent by the prover.
            s: Response sent by the prover.
            public_key: Prover's public key ``y``.
            challenge: Challenge ``c`` issued by the verifier.
        """
        # Inputs are normalised to built-in ints: three-argument pow() needs
        # them, and NumPy integer scalars do not implement modular pow.
        lhs = pow(self.g, int(s), self.p)
        rhs = (int(R) * pow(int(public_key), int(challenge), self.p)) % self.p
        return lhs == rhs

# --- 2. Differential Privacy (Laplace Mechanism) ---
# Implements Section II-B: Privacy Preservation on Predictions
def apply_laplace_noise(values, epsilon, sensitivity=1.0):
    """Perturb ``values`` element-wise with zero-centred Laplace noise.

    Args:
        values: Array to perturb; its ``.shape`` sets the shape of the noise.
        epsilon: Privacy budget. A non-positive value switches the mechanism
            off and ``values`` is handed back untouched.
        sensitivity: Numerator of the noise scale ``sensitivity / epsilon``.

    Returns:
        ``values`` plus independently drawn noise, or ``values`` itself when
        ``epsilon <= 0``.
    """
    privacy_disabled = epsilon <= 0
    if privacy_disabled:
        return values

    # Noise comes from NumPy's global generator, so it follows np.random.seed.
    perturbation = np.random.laplace(
        loc=0,
        scale=sensitivity / epsilon,
        size=values.shape,
    )
    return values + perturbation

# --- 3. Robust Aggregation (Median) ---
# Implements Section II-F: Mitigation of Poisoning Attacks
def robust_aggregation(updates):
    """Combine client updates by taking the median across clients.

    Args:
        updates: Sequence of equally shaped arrays, one per client.

    Returns:
        Array holding, for every position, the median over the first axis
        (the client axis) of the stacked updates.
    """
    elementwise_median = np.median(updates, axis=0)
    return elementwise_median

**3. Data Loading & Preprocessing**

In [5]:
# Update this path to your specific dataset location
dataset_path = '/content/drive/MyDrive/IDS-IOT2024/Process_1 IDS-IoT-2024.csv'
target_column = 'Attack_Category_x'

# Fail fast when the dataset is missing. Printing a message and carrying on
# would leave X/y undefined (or stale from an earlier kernel run), and the
# failure would only surface later as a confusing NameError or as silently
# reused old data.
if not os.path.exists(dataset_path):
    raise FileNotFoundError(
        f"Error: Dataset not found at {dataset_path}. Please upload the file."
    )

df = pd.read_csv(dataset_path)

# Feature and Target Extraction
# Features are every column except the last one. The label is selected by
# name, so if it is not the last column it must be removed from the feature
# slice explicitly; otherwise the models would be trained on the answer.
# NOTE(review): confirm which column the positional "last column" drop is
# meant to remove in this dataset.
feature_frame = df.iloc[:, :-1]
if target_column in feature_frame.columns:
    feature_frame = feature_frame.drop(columns=[target_column])
X = feature_frame.values
y = df[target_column].values

# Encode target variable as consecutive integer class ids
label_encoder = LabelEncoder()
y_encoded = label_encoder.fit_transform(y)

# Split data (Keeping test set separate for global evaluation)
X_train_full, X_test, y_train_full, y_test = train_test_split(
    X, y_encoded, test_size=0.2, random_state=42
)

# Shuffle for IID simulation (uses NumPy's global RNG, so it follows np.random.seed)
shuffled_indices = np.random.permutation(len(X_train_full))
X_train_shuffled = X_train_full[shuffled_indices]
y_train_shuffled = y_train_full[shuffled_indices]

print("Data loaded and preprocessed successfully.")

Data loaded and preprocessed successfully.


**4. Main Experiment Loop**

In [None]:
# --- Simulation Parameters ---
# Federation sizes to sweep; shrink to e.g. [10] for a quick smoke run.
client_counts = [10, 25, 50, 75, 100]
rounds = 10            # federated rounds executed for each federation size
privacy_budget = 1.5   # Epsilon

# Schnorr Parameters (modulus and base handed to SchnorrProtocol)
p, g = 104729, 2

# LightGBM Parameters
lgb_params = dict(
    objective='multiclass',
    num_class=len(np.unique(y_encoded)),  # one output per encoded label value
    boosting_type='gbdt',
    metric='multi_logloss',
    learning_rate=0.1,
    num_leaves=31,
    max_depth=-1,
    verbosity=-1,
)

# One summary dict is appended per entry of client_counts.
results = []

# Runs the full simulation once per federation size in `client_counts` and
# appends one summary dict per size to `results`.
for client_count in client_counts:
    print(f"\nRunning experiment for {client_count} clients...")

    # 1. Distribute Data
    # Contiguous, equally sized shards of the shuffled training set; the last
    # client additionally receives the remainder of the integer division.
    client_data = []
    samples_per_client = len(X_train_shuffled) // client_count
    for i in range(client_count):
        start_idx = i * samples_per_client
        end_idx = start_idx + samples_per_client if i < client_count - 1 else len(X_train_shuffled)
        X_client = X_train_shuffled[start_idx:end_idx]
        y_client = y_train_shuffled[start_idx:end_idx]
        client_data.append((X_client, y_client))

    # 2. Assign Adversarial Roles
    # Sets give constant-time membership checks inside the client loop; the
    # random draws themselves are unchanged.
    # Type A: Unverified Clients (Fail ZKP)
    num_unverified = min(2, client_count)
    unverified_clients = set(random.sample(range(client_count), num_unverified))

    # Type B: Insider Poisoners (Pass ZKP, but send bad updates)
    remaining_clients = [c for c in range(client_count) if c not in unverified_clients]
    num_poisoners = min(3, len(remaining_clients))
    poisoning_clients = set(random.sample(remaining_clients, num_poisoners))

    # Reset Metrics
    excluded_clients = 0  # failed authentications, summed over all rounds
    local_accuracies = []
    global_accuracies = []
    client_auth_times = []
    client_train_times = []
    round_latencies = []
    aggregation_times = []
    communication_overheads = []
    # Bound before the rounds loop so the model-size step below never reads
    # an undefined or stale list when `rounds` is 0.
    local_models = []

    # Initialize Protocols (one long-lived key pair per client)
    client_protocols = [SchnorrProtocol(p, g) for _ in range(client_count)]

    # --- Federated Rounds ---
    for round_num in range(rounds):
        print(f"  Round {round_num + 1}/{rounds}")
        # perf_counter() is a monotonic, high-resolution clock intended for
        # measuring intervals; time.time() is wall-clock and can be coarse.
        round_start_time = time.perf_counter()

        local_models = []          # For saving model size later
        local_predictions = []     # For aggregation
        round_client_auth_times = []
        round_client_train_times = []
        round_comm_overhead = 0

        # --- Client Phase ---
        for client_num, (X_c, y_c) in enumerate(client_data):

            # A. Authentication (ZKP)
            auth_start_time = time.perf_counter()
            schnorr = client_protocols[client_num]
            R = schnorr.generate_commitment()
            challenge = np.random.randint(1, p)

            response = schnorr.compute_response(challenge)
            # Simulating ZKP Failure for Unverified Clients: shifting the
            # response by one makes the verification equation fail.
            if client_num in unverified_clients:
                response += 1

            # Verify; the authentication time is recorded for accepted and
            # rejected clients alike.
            is_verified = schnorr.verify(R, response, schnorr.public_key, challenge)
            round_client_auth_times.append(time.perf_counter() - auth_start_time)
            if not is_verified:
                excluded_clients += 1
                continue  # Skip training for unverified clients

            # B. Local Training
            train_start_time = time.perf_counter()

            # Simulating Data Poisoning (Insiders): Gaussian noise on features
            if client_num in poisoning_clients:
                X_train_curr = X_c + np.random.normal(0, 2.0, X_c.shape)
            else:
                X_train_curr = X_c

            train_data = lgb.Dataset(X_train_curr, label=y_c)
            local_model = lgb.train(lgb_params, train_data, num_boost_round=20)

            round_client_train_times.append(time.perf_counter() - train_start_time)

            # Measure Model Size for Communication Overhead
            local_model_bytes = pickle.dumps(local_model)
            round_comm_overhead += len(local_model_bytes)
            local_models.append(local_model)

            # Predictions on the shared test set
            preds = local_model.predict(X_test)

            # Simulating Model Poisoning (Insiders send garbage predictions)
            if client_num in poisoning_clients:
                preds = np.random.rand(*preds.shape)  # Random noise predictions

            # C. Differential Privacy
            noisy_preds = apply_laplace_noise(preds, epsilon=privacy_budget)
            local_predictions.append(noisy_preds)

            # Local Accuracy Tracking (measured on the noised predictions)
            local_pred_labels = np.argmax(noisy_preds, axis=1)
            local_acc = accuracy_score(y_test, local_pred_labels)
            local_accuracies.append(local_acc)

        # --- Server Phase ---
        if local_predictions:
            # D. Robust Aggregation
            # Only the aggregation call is timed, so the reported aggregation
            # time no longer includes argmax and accuracy scoring.
            agg_start_time = time.perf_counter()
            aggregated_preds = robust_aggregation(local_predictions)
            aggregation_times.append(time.perf_counter() - agg_start_time)

            # Global Evaluation
            y_pred_global = np.argmax(aggregated_preds, axis=1)
            global_acc = accuracy_score(y_test, y_pred_global)
            global_accuracies.append(global_acc)

            # Measure Aggregated Model Size (Communication back to clients)
            agg_model_bytes = pickle.dumps(aggregated_preds)
            round_comm_overhead += len(agg_model_bytes)

            print(f"  Round {round_num + 1} Global Accuracy: {global_acc:.4f}")
        else:
            print("  No valid updates received.")

        # Round Metrics
        communication_overheads.append(round_comm_overhead)
        round_latencies.append(time.perf_counter() - round_start_time)
        client_auth_times.extend(round_client_auth_times)
        client_train_times.extend(round_client_train_times)

    # Save Global Model to calculate exact file size (MB)
    # NOTE(review): the object written is the first accepted client's model
    # from the final round, not an aggregated model - confirm this is the
    # intended proxy for "global model size".
    if local_models:
        model_filename = f"global_model_{client_count}.pkl"
        with open(model_filename, "wb") as f:
            pickle.dump(local_models[0], f)
        global_model_size = os.path.getsize(model_filename) / (1024 ** 2)  # MB
    else:
        global_model_size = 0

    # Store Final Results
    results.append({
        'clients': client_count,
        'avg_global_accuracy': np.mean(global_accuracies) if global_accuracies else 0,
        'avg_local_accuracy': np.mean(local_accuracies) if local_accuracies else 0,
        'avg_train_time': np.mean(client_train_times) if client_train_times else 0,
        'avg_auth_time': np.mean(client_auth_times) if client_auth_times else 0,
        'avg_round_latency': np.mean(round_latencies) if round_latencies else 0,
        'avg_aggregation_time': np.mean(aggregation_times) if aggregation_times else 0,
        'avg_comm_overhead': np.mean(communication_overheads) / (1024 ** 2) if communication_overheads else 0,
        'global_model_size': global_model_size,
        # Previously counted but never reported.
        'excluded_clients': excluded_clients,
    })

**5. Final Output Printing**

In [None]:
# Emit one comma-separated summary line per federation size.
print("\nExperiment Results:")
for summary in results:
    fields = [
        f"Clients: {summary['clients']}",
        f"Avg Global Accuracy: {summary['avg_global_accuracy']:.5f}",
        f"Avg Local Accuracy: {summary['avg_local_accuracy']:.5f}",
        f"Avg Train Time: {summary['avg_train_time']:.5f}s",
        f"Avg Auth Time: {summary['avg_auth_time']:.5f}s",
        f"Avg Round Latency: {summary['avg_round_latency']:.5f}s",
        f"Avg Aggregation Time: {summary['avg_aggregation_time']:.5f}s",
        f"Avg Comm Overhead: {summary['avg_comm_overhead']:.5f}MB",
        f"Global Model Size: {summary['global_model_size']:.5f}MB",
    ]
    print(", ".join(fields))