In [8]:
import pandas as pd
import numpy as np
from sklearn.utils import shuffle
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, accuracy_score
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import sys

In [9]:
data_path = 'F:/Documents/CRCE/Project/NIDS/dataset/Edge-IIoT/Edge-IIoTset dataset/Selected dataset for ML and DL/DNN-EdgeIIoT-dataset.csv'  # Replace with your actual path
df = pd.read_csv(data_path, low_memory=False)
df.head()

Unnamed: 0,frame.time,ip.src_host,ip.dst_host,arp.dst.proto_ipv4,arp.opcode,arp.hw.size,arp.src.proto_ipv4,icmp.checksum,icmp.seq_le,icmp.transmit_timestamp,...,mqtt.proto_len,mqtt.protoname,mqtt.topic,mqtt.topic_len,mqtt.ver,mbtcp.len,mbtcp.trans_id,mbtcp.unit_id,Attack_label,Attack_type
0,2021 11:44:10.081753000,192.168.0.128,192.168.0.101,0,0.0,0.0,0,0.0,0.0,0.0,...,0.0,0,0,0.0,0.0,0.0,0.0,0.0,0,Normal
1,2021 11:44:10.162218000,192.168.0.101,192.168.0.128,0,0.0,0.0,0,0.0,0.0,0.0,...,4.0,MQTT,0,0.0,4.0,0.0,0.0,0.0,0,Normal
2,2021 11:44:10.162271000,192.168.0.128,192.168.0.101,0,0.0,0.0,0,0.0,0.0,0.0,...,0.0,0,0,0.0,0.0,0.0,0.0,0.0,0,Normal
3,2021 11:44:10.162641000,192.168.0.128,192.168.0.101,0,0.0,0.0,0,0.0,0.0,0.0,...,0.0,0,0,0.0,0.0,0.0,0.0,0.0,0,Normal
4,2021 11:44:10.166132000,192.168.0.101,192.168.0.128,0,0.0,0.0,0,0.0,0.0,0.0,...,0.0,0,Temperature_and_Humidity,24.0,0.0,0.0,0.0,0.0,0,Normal


In [10]:
# -------------------- 1. Preprocessing -------------------- #
# Drop unnecessary columns
drop_columns = [
    "frame.time", "ip.src_host", "ip.dst_host", "arp.src.proto_ipv4", "arp.dst.proto_ipv4",
    "http.file_data", "http.request.full_uri", "icmp.transmit_timestamp", "http.request.uri.query",
    "tcp.options", "tcp.payload", "tcp.srcport", "tcp.dstport", "udp.port", "mqtt.msg"
]
df.drop(drop_columns, axis=1, inplace=True, errors='ignore')

# Remove rows with NaN or duplicates
df.dropna(axis=0, how='any', inplace=True)
df.drop_duplicates(subset=None, keep="first", inplace=True)

# Remove leading/trailing spaces from strings
df = df.map(lambda x: x.strip() if isinstance(x, str) else x)

# Shuffle data
df = shuffle(df)

# Helper function to one-hot encode categorical columns
def encode_text_dummy(df, name):
    dummies = pd.get_dummies(df[name], prefix=name)
    df = pd.concat([df, dummies], axis=1)
    return df.drop(name, axis=1)

# Encode categorical features
for col in ['http.request.method', 'http.referer', 'http.request.version',
            'dns.qry.name.len', 'mqtt.conack.flags', 'mqtt.protoname', 'mqtt.topic']:
    if col in df.columns:
        df = encode_text_dummy(df, col)

# Drop any remaining non-numeric columns
# df = df.select_dtypes(include=[np.number])

# Save 'Attack_type' before dropping non-numeric columns
labels = df['Attack_type']

# Drop any remaining non-numeric columns (preserve only numeric features)
df = df.select_dtypes(include=[np.number]).copy()

# Add 'Attack_type' column back
df['Attack_type'] = labels

In [3]:
df.columns

Index(['frame.time', 'ip.src_host', 'ip.dst_host', 'arp.dst.proto_ipv4',
       'arp.opcode', 'arp.hw.size', 'arp.src.proto_ipv4', 'icmp.checksum',
       'icmp.seq_le', 'icmp.transmit_timestamp', 'icmp.unused',
       'http.file_data', 'http.content_length', 'http.request.uri.query',
       'http.request.method', 'http.referer', 'http.request.full_uri',
       'http.request.version', 'http.response', 'http.tls_port', 'tcp.ack',
       'tcp.ack_raw', 'tcp.checksum', 'tcp.connection.fin',
       'tcp.connection.rst', 'tcp.connection.syn', 'tcp.connection.synack',
       'tcp.dstport', 'tcp.flags', 'tcp.flags.ack', 'tcp.len', 'tcp.options',
       'tcp.payload', 'tcp.seq', 'tcp.srcport', 'udp.port', 'udp.stream',
       'udp.time_delta', 'dns.qry.name', 'dns.qry.name.len', 'dns.qry.qu',
       'dns.qry.type', 'dns.retransmission', 'dns.retransmit_request',
       'dns.retransmit_request_in', 'mqtt.conack.flags',
       'mqtt.conflag.cleansess', 'mqtt.conflags', 'mqtt.hdrflags', 'mqtt.len

In [11]:
# -------------------- 2. Feature Prep -------------------- #
# Separate labels and features
if 'Attack_type' in df.columns:
    labels = df['Attack_type']
    df.drop(['Attack_type'], axis=1, inplace=True)
else:
    raise ValueError("Attack_type column is missing.")

# Optional: also drop 'Attack_label' if still present
df.drop(['Attack_label'], axis=1, inplace=True, errors='ignore')

# Encode target labels
le = LabelEncoder()
labels_encoded = le.fit_transform(labels)

# Scale features
scaler = StandardScaler()
features_scaled = scaler.fit_transform(df)

# Convert to tensors
X = torch.tensor(features_scaled, dtype=torch.float32).unsqueeze(1)  # Add channel dim
y = torch.tensor(labels_encoded, dtype=torch.long)


In [12]:
# -------------------- 3. Data Split -------------------- #
X_trainval, X_test, y_trainval, y_test = train_test_split(X, y, test_size=0.2, stratify=y, random_state=42)
X_train, X_val, y_train, y_val = train_test_split(X_trainval, y_trainval, test_size=0.25, stratify=y_trainval, random_state=42)

batch_size = 64
train_loader = DataLoader(TensorDataset(X_train, y_train), batch_size=batch_size, shuffle=True)
val_loader = DataLoader(TensorDataset(X_val, y_val), batch_size=batch_size)
test_loader = DataLoader(TensorDataset(X_test, y_test), batch_size=batch_size)


In [13]:
import joblib

# Save label encoder
joblib.dump(le, "pso_label_encoder.pkl")

# Save the scaler
joblib.dump(scaler, "pso_scaler.pkl")

# Save feature column order
joblib.dump(df.columns.tolist(), "pso_feature_columns.pkl")


['pso_feature_columns.pkl']

In [14]:
# -------------------- 4. CNN Model -------------------- #
num_class = len(le.classes_)

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv1d(1, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(64, 32, kernel_size=3, padding=1)
        self.conv3 = nn.Conv1d(32, 16, kernel_size=3, padding=1)        
        self.fc1 = nn.Linear(X.shape[2] * 16, 30)
        self.out_layer = nn.Linear(30, num_class)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.max_pool1d(x, kernel_size=1)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        return F.log_softmax(self.out_layer(x), dim=1)

In [15]:
def get_flat_params(model):
    return torch.cat([p.data.view(-1) for p in model.parameters()])

def set_flat_params(model, flat_params):
    idx = 0
    for p in model.parameters():
        param_length = p.numel()
        p.data = flat_params[idx:idx+param_length].view(p.size()).to(p.device)
        idx += param_length


In [20]:
import random

n_particles = 30
w = 0.7       # inertia
c1 = 1.5      # personal influence
c2 = 1.5      # social influence
iterations = 50

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = CNN().to(device)
dim = get_flat_params(model).shape[0]

# Initialize particles
particles = [torch.randn(dim).to(device) for _ in range(n_particles)]
velocities = [torch.randn(dim).to(device) * 0.1 for _ in range(n_particles)]
p_best = particles.copy()
p_best_scores = [float('inf')] * n_particles
g_best = None
g_best_score = float('inf')


Define fitness function

In [21]:
def evaluate_particle(model, flat_params):
    set_flat_params(model, flat_params)
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = criterion(output, target)
            val_loss += loss.item() * data.size(0)
    return val_loss / len(val_loader.dataset)


In [None]:
for iter in range(iterations):
    print(f"\n=== PSO Iteration {iter+1}/{iterations} ===")
    
    for i in range(n_particles):
        score = evaluate_particle(model, particles[i])
        print(f"Particle {i+1}/{n_particles} - Validation Loss: {score:.4f}")

        # Personal best update
        if score < p_best_scores[i]:
            p_best[i] = particles[i].clone()
            p_best_scores[i] = score

        # Global best update
        if score < g_best_score:
            g_best = particles[i].clone()
            g_best_score = score

    # Velocity and position update
    for i in range(n_particles):
        r1 = torch.rand(dim).to(device)
        r2 = torch.rand(dim).to(device)

        velocities[i] = w * velocities[i] + \
                        c1 * r1 * (p_best[i] - particles[i]) + \
                        c2 * r2 * (g_best - particles[i])
        particles[i] = particles[i] + velocities[i]

    # Set the best weights so far
    set_flat_params(model, g_best)

    # Evaluate on validation set with best weights
    model.eval()
    val_preds, val_true = [], []
    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1)
            val_preds.extend(pred.cpu().numpy())
            val_true.extend(target.cpu().numpy())

    val_acc = accuracy_score(val_true, val_preds)
    print(f">>> Iteration {iter+1} - Best Validation Accuracy: {val_acc:.4f}")
    print(">>> Classification Report:")
    print(classification_report(val_true, val_preds, target_names=le.classes_))



=== PSO Iteration 1/50 ===
Particle 1/30 - Validation Loss: 3856.6492
Particle 2/30 - Validation Loss: 7576.0907
Particle 3/30 - Validation Loss: 10181.1167
Particle 4/30 - Validation Loss: 11378.7802
Particle 5/30 - Validation Loss: 7421.9290
Particle 6/30 - Validation Loss: 10839.3071
Particle 7/30 - Validation Loss: 5190.7240
Particle 8/30 - Validation Loss: 7869.1969
Particle 9/30 - Validation Loss: 12122.1263
Particle 10/30 - Validation Loss: 12138.9122
Particle 11/30 - Validation Loss: 3988.0061
Particle 12/30 - Validation Loss: 10720.9254
Particle 13/30 - Validation Loss: 7652.6237
Particle 14/30 - Validation Loss: 10573.7629
Particle 15/30 - Validation Loss: 10189.3674
Particle 16/30 - Validation Loss: 19204.0801
Particle 17/30 - Validation Loss: 12783.8969
Particle 18/30 - Validation Loss: 12948.4242
Particle 19/30 - Validation Loss: 19850.0982
Particle 20/30 - Validation Loss: 9861.2953
Particle 21/30 - Validation Loss: 5440.6133
Particle 22/30 - Validation Loss: 11256.9245


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                       precision    recall  f1-score   support

             Backdoor       0.01      0.04      0.01      4805
            DDoS_HTTP       0.00      0.00      0.00      9709
            DDoS_ICMP       0.00      0.00      0.00     13588
             DDoS_TCP       0.00      0.00      0.00     10012
             DDoS_UDP       0.00      0.00      0.00     24313
       Fingerprinting       0.00      0.10      0.00       171
                 MITM       0.00      0.00      0.00        72
               Normal       0.79      0.46      0.58    272799
             Password       0.01      0.01      0.01      9987
        Port_Scanning       0.01      0.48      0.03      3995
           Ransomware       0.50      0.00      0.00      1938
        SQL_injection       0.00      0.00      0.00     10165
            Uploading       0.12      0.07      0.09      7362
Vulnerability_scanner       0.00      0.00      0.00     10005
                  XSS       0.00      0.00      0.00  

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                       precision    recall  f1-score   support

             Backdoor       0.01      0.04      0.01      4805
            DDoS_HTTP       0.00      0.00      0.00      9709
            DDoS_ICMP       0.00      0.00      0.00     13588
             DDoS_TCP       0.00      0.00      0.00     10012
             DDoS_UDP       0.00      0.00      0.00     24313
       Fingerprinting       0.00      0.10      0.00       171
                 MITM       0.00      0.00      0.00        72
               Normal       0.79      0.46      0.58    272799
             Password       0.01      0.01      0.01      9987
        Port_Scanning       0.01      0.48      0.03      3995
           Ransomware       0.50      0.00      0.00      1938
        SQL_injection       0.00      0.00      0.00     10165
            Uploading       0.12      0.07      0.09      7362
Vulnerability_scanner       0.00      0.00      0.00     10005
                  XSS       0.00      0.00      0.00  

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                       precision    recall  f1-score   support

             Backdoor       0.01      0.04      0.01      4805
            DDoS_HTTP       0.00      0.00      0.00      9709
            DDoS_ICMP       0.00      0.00      0.00     13588
             DDoS_TCP       0.91      0.02      0.03     10012
             DDoS_UDP       0.00      0.00      0.00     24313
       Fingerprinting       0.00      0.10      0.00       171
                 MITM       0.00      0.00      0.00        72
               Normal       0.79      0.42      0.55    272799
             Password       0.00      0.00      0.00      9987
        Port_Scanning       0.02      0.48      0.03      3995
           Ransomware       0.00      0.00      0.00      1938
        SQL_injection       0.00      0.00      0.00     10165
            Uploading       0.14      0.09      0.11      7362
Vulnerability_scanner       0.00      0.00      0.00     10005
                  XSS       0.01      0.03      0.02  

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                       precision    recall  f1-score   support

             Backdoor       0.02      0.02      0.02      4805
            DDoS_HTTP       0.03      0.97      0.06      9709
            DDoS_ICMP       0.00      0.00      0.00     13588
             DDoS_TCP       0.00      0.00      0.00     10012
             DDoS_UDP       0.00      0.00      0.00     24313
       Fingerprinting       0.00      0.00      0.00       171
                 MITM       0.00      0.00      0.00        72
               Normal       0.35      0.05      0.09    272799
             Password       0.00      0.00      0.00      9987
        Port_Scanning       0.00      0.00      0.00      3995
           Ransomware       0.00      0.00      0.00      1938
        SQL_injection       0.00      0.00      0.00     10165
            Uploading       0.03      0.06      0.04      7362
Vulnerability_scanner       0.00      0.00      0.00     10005
                  XSS       0.00      0.00      0.00  

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                       precision    recall  f1-score   support

             Backdoor       0.02      0.02      0.02      4805
            DDoS_HTTP       0.03      0.97      0.06      9709
            DDoS_ICMP       0.00      0.00      0.00     13588
             DDoS_TCP       0.00      0.00      0.00     10012
             DDoS_UDP       0.00      0.00      0.00     24313
       Fingerprinting       0.00      0.00      0.00       171
                 MITM       0.00      0.00      0.00        72
               Normal       0.35      0.05      0.09    272799
             Password       0.00      0.00      0.00      9987
        Port_Scanning       0.00      0.00      0.00      3995
           Ransomware       0.00      0.00      0.00      1938
        SQL_injection       0.00      0.00      0.00     10165
            Uploading       0.03      0.06      0.04      7362
Vulnerability_scanner       0.00      0.00      0.00     10005
                  XSS       0.00      0.00      0.00  

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                       precision    recall  f1-score   support

             Backdoor       0.02      0.02      0.02      4805
            DDoS_HTTP       0.03      0.97      0.06      9709
            DDoS_ICMP       0.00      0.00      0.00     13588
             DDoS_TCP       0.00      0.00      0.00     10012
             DDoS_UDP       0.00      0.00      0.00     24313
       Fingerprinting       0.00      0.00      0.00       171
                 MITM       0.00      0.00      0.00        72
               Normal       0.35      0.05      0.09    272799
             Password       0.00      0.00      0.00      9987
        Port_Scanning       0.00      0.00      0.00      3995
           Ransomware       0.00      0.00      0.00      1938
        SQL_injection       0.00      0.00      0.00     10165
            Uploading       0.03      0.06      0.04      7362
Vulnerability_scanner       0.00      0.00      0.00     10005
                  XSS       0.00      0.00      0.00  

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                       precision    recall  f1-score   support

             Backdoor       0.02      0.02      0.02      4805
            DDoS_HTTP       0.03      0.97      0.06      9709
            DDoS_ICMP       0.00      0.00      0.00     13588
             DDoS_TCP       0.00      0.00      0.00     10012
             DDoS_UDP       0.00      0.00      0.00     24313
       Fingerprinting       0.00      0.00      0.00       171
                 MITM       0.00      0.00      0.00        72
               Normal       0.35      0.05      0.09    272799
             Password       0.00      0.00      0.00      9987
        Port_Scanning       0.00      0.00      0.00      3995
           Ransomware       0.00      0.00      0.00      1938
        SQL_injection       0.00      0.00      0.00     10165
            Uploading       0.03      0.06      0.04      7362
Vulnerability_scanner       0.00      0.00      0.00     10005
                  XSS       0.00      0.00      0.00  

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                       precision    recall  f1-score   support

             Backdoor       0.02      0.02      0.02      4805
            DDoS_HTTP       0.03      0.97      0.06      9709
            DDoS_ICMP       0.00      0.00      0.00     13588
             DDoS_TCP       0.00      0.00      0.00     10012
             DDoS_UDP       0.00      0.00      0.00     24313
       Fingerprinting       0.00      0.00      0.00       171
                 MITM       0.00      0.00      0.00        72
               Normal       0.35      0.05      0.09    272799
             Password       0.00      0.00      0.00      9987
        Port_Scanning       0.00      0.00      0.00      3995
           Ransomware       0.00      0.00      0.00      1938
        SQL_injection       0.00      0.00      0.00     10165
            Uploading       0.03      0.06      0.04      7362
Vulnerability_scanner       0.00      0.00      0.00     10005
                  XSS       0.00      0.00      0.00  

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                       precision    recall  f1-score   support

             Backdoor       0.02      0.02      0.02      4805
            DDoS_HTTP       0.03      0.97      0.06      9709
            DDoS_ICMP       0.00      0.00      0.00     13588
             DDoS_TCP       0.00      0.00      0.00     10012
             DDoS_UDP       0.00      0.00      0.00     24313
       Fingerprinting       0.00      0.00      0.00       171
                 MITM       0.00      0.00      0.00        72
               Normal       0.35      0.05      0.09    272799
             Password       0.00      0.00      0.00      9987
        Port_Scanning       0.00      0.00      0.00      3995
           Ransomware       0.00      0.00      0.00      1938
        SQL_injection       0.00      0.00      0.00     10165
            Uploading       0.03      0.06      0.04      7362
Vulnerability_scanner       0.00      0.00      0.00     10005
                  XSS       0.00      0.00      0.00  

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Particle 1/30 - Validation Loss: 7747.0010
Particle 2/30 - Validation Loss: 17883.2663
Particle 3/30 - Validation Loss: 47822.2105
Particle 4/30 - Validation Loss: 12023.8629
Particle 5/30 - Validation Loss: 22564.2324
Particle 6/30 - Validation Loss: 25358.2373
Particle 7/30 - Validation Loss: 14164.5602
Particle 8/30 - Validation Loss: 21417.5130
Particle 9/30 - Validation Loss: 13333.1834
Particle 10/30 - Validation Loss: 29874.6527
Particle 11/30 - Validation Loss: 9607.8469
Particle 12/30 - Validation Loss: 8215.4429
Particle 13/30 - Validation Loss: 15989.5000
Particle 14/30 - Validation Loss: 9918.9281
Particle 15/30 - Validation Loss: 17660.0411
Particle 16/30 - Validation Loss: 12219.9672
Particle 17/30 - Validation Loss: 3193.4491
Particle 18/30 - Validation Loss: 23565.7888
Particle 19/30 - Validation Loss: 27849.8920
Particle 20/30 - Validation Loss: 10794.7188
Particle 21/30 - Validation Loss: 14746.1061
Particle 22/30 - Validation Loss: 12856.6200
Particle 23/30 - Validat

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                       precision    recall  f1-score   support

             Backdoor       0.02      0.02      0.02      4805
            DDoS_HTTP       0.03      0.97      0.06      9709
            DDoS_ICMP       0.00      0.00      0.00     13588
             DDoS_TCP       0.00      0.00      0.00     10012
             DDoS_UDP       0.00      0.00      0.00     24313
       Fingerprinting       0.00      0.00      0.00       171
                 MITM       0.00      0.00      0.00        72
               Normal       0.35      0.05      0.09    272799
             Password       0.00      0.00      0.00      9987
        Port_Scanning       0.00      0.00      0.00      3995
           Ransomware       0.00      0.00      0.00      1938
        SQL_injection       0.00      0.00      0.00     10165
            Uploading       0.03      0.06      0.04      7362
Vulnerability_scanner       0.00      0.00      0.00     10005
                  XSS       0.00      0.00      0.00  

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                       precision    recall  f1-score   support

             Backdoor       0.02      0.02      0.02      4805
            DDoS_HTTP       0.03      0.97      0.06      9709
            DDoS_ICMP       0.00      0.00      0.00     13588
             DDoS_TCP       0.00      0.00      0.00     10012
             DDoS_UDP       0.00      0.00      0.00     24313
       Fingerprinting       0.00      0.00      0.00       171
                 MITM       0.00      0.00      0.00        72
               Normal       0.35      0.05      0.09    272799
             Password       0.00      0.00      0.00      9987
        Port_Scanning       0.00      0.00      0.00      3995
           Ransomware       0.00      0.00      0.00      1938
        SQL_injection       0.00      0.00      0.00     10165
            Uploading       0.03      0.06      0.04      7362
Vulnerability_scanner       0.00      0.00      0.00     10005
                  XSS       0.00      0.00      0.00  

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                       precision    recall  f1-score   support

             Backdoor       0.02      0.02      0.02      4805
            DDoS_HTTP       0.03      0.97      0.06      9709
            DDoS_ICMP       0.00      0.00      0.00     13588
             DDoS_TCP       0.00      0.00      0.00     10012
             DDoS_UDP       0.00      0.00      0.00     24313
       Fingerprinting       0.00      0.00      0.00       171
                 MITM       0.00      0.00      0.00        72
               Normal       0.35      0.05      0.09    272799
             Password       0.00      0.00      0.00      9987
        Port_Scanning       0.00      0.00      0.00      3995
           Ransomware       0.00      0.00      0.00      1938
        SQL_injection       0.00      0.00      0.00     10165
            Uploading       0.03      0.06      0.04      7362
Vulnerability_scanner       0.00      0.00      0.00     10005
                  XSS       0.00      0.00      0.00  

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Particle 1/30 - Validation Loss: 11917.6686
Particle 2/30 - Validation Loss: 6814.6971
Particle 3/30 - Validation Loss: 31656.8551
Particle 4/30 - Validation Loss: 19228.0182
Particle 5/30 - Validation Loss: 12239.8694
Particle 6/30 - Validation Loss: 14107.2701
Particle 7/30 - Validation Loss: 8558.7861
Particle 8/30 - Validation Loss: 23961.0942
Particle 9/30 - Validation Loss: 4425.5631
Particle 10/30 - Validation Loss: 14517.6514
Particle 11/30 - Validation Loss: 34957.8024
Particle 12/30 - Validation Loss: 5394.3806
Particle 13/30 - Validation Loss: 74834.3199
Particle 14/30 - Validation Loss: 11925.0652
Particle 15/30 - Validation Loss: 26225.2880
Particle 16/30 - Validation Loss: 13883.4626
Particle 17/30 - Validation Loss: 7991.7535
Particle 18/30 - Validation Loss: 32960.4467
Particle 19/30 - Validation Loss: 71422.7998
Particle 20/30 - Validation Loss: 29551.5815
Particle 21/30 - Validation Loss: 22054.0947
Particle 22/30 - Validation Loss: 6878.8898
Particle 23/30 - Validati

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Particle 1/30 - Validation Loss: 8727.6287
Particle 2/30 - Validation Loss: 3175.2121
Particle 3/30 - Validation Loss: 68299.8380
Particle 4/30 - Validation Loss: 34977.5807
Particle 5/30 - Validation Loss: 19142.6563
Particle 6/30 - Validation Loss: 19023.7691
Particle 7/30 - Validation Loss: 7030.3035
Particle 8/30 - Validation Loss: 115307.8877
Particle 9/30 - Validation Loss: 5650.6184
Particle 10/30 - Validation Loss: 40358.9627
Particle 11/30 - Validation Loss: 5090.3222
Particle 12/30 - Validation Loss: 3466.4746
Particle 13/30 - Validation Loss: 24234.0623
Particle 14/30 - Validation Loss: 16906.1115
Particle 15/30 - Validation Loss: 13884.5193
Particle 16/30 - Validation Loss: 8776.3827
Particle 17/30 - Validation Loss: 3634.6908
Particle 18/30 - Validation Loss: 22263.6285
Particle 19/30 - Validation Loss: 34631.8131
Particle 20/30 - Validation Loss: 27754.6451
Particle 21/30 - Validation Loss: 7843.3452
Particle 22/30 - Validation Loss: 10372.7905
Particle 23/30 - Validation

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                       precision    recall  f1-score   support

             Backdoor       0.02      0.02      0.02      4805
            DDoS_HTTP       0.03      0.97      0.06      9709
            DDoS_ICMP       0.00      0.00      0.00     13588
             DDoS_TCP       0.00      0.00      0.00     10012
             DDoS_UDP       0.00      0.00      0.00     24313
       Fingerprinting       0.00      0.00      0.00       171
                 MITM       0.00      0.00      0.00        72
               Normal       0.35      0.05      0.09    272799
             Password       0.00      0.00      0.00      9987
        Port_Scanning       0.00      0.00      0.00      3995
           Ransomware       0.00      0.00      0.00      1938
        SQL_injection       0.00      0.00      0.00     10165
            Uploading       0.03      0.06      0.04      7362
Vulnerability_scanner       0.00      0.00      0.00     10005
                  XSS       0.00      0.00      0.00  

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Particle 1/30 - Validation Loss: 21521.8130
Particle 2/30 - Validation Loss: 4250.8949
Particle 3/30 - Validation Loss: 33736.7065
Particle 4/30 - Validation Loss: 32609.8314
Particle 5/30 - Validation Loss: 8080.4688
Particle 6/30 - Validation Loss: 18454.7879
Particle 7/30 - Validation Loss: 12408.2037
Particle 8/30 - Validation Loss: 18102.4305
Particle 9/30 - Validation Loss: 14315.8436
Particle 10/30 - Validation Loss: 43252.9771
Particle 11/30 - Validation Loss: 26396.1235
Particle 12/30 - Validation Loss: 8146.0037
Particle 13/30 - Validation Loss: 36583.0671
Particle 14/30 - Validation Loss: 15432.2851
Particle 15/30 - Validation Loss: 33308.4263
Particle 16/30 - Validation Loss: 10231.2991
Particle 17/30 - Validation Loss: 7842.8643
Particle 18/30 - Validation Loss: 20226.8568
Particle 19/30 - Validation Loss: 15994.2406
Particle 20/30 - Validation Loss: 20593.5137
Particle 21/30 - Validation Loss: 23549.7500
Particle 22/30 - Validation Loss: 35549.7366
Particle 23/30 - Valida

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Particle 1/30 - Validation Loss: 51371.1414
Particle 2/30 - Validation Loss: 5841.1540
Particle 3/30 - Validation Loss: 32449.3931
Particle 4/30 - Validation Loss: 18479.1668
Particle 5/30 - Validation Loss: 17678.5631
Particle 6/30 - Validation Loss: 31751.9094
Particle 7/30 - Validation Loss: 6717.5732
Particle 8/30 - Validation Loss: 29054.2297
Particle 9/30 - Validation Loss: 15574.4945
Particle 10/30 - Validation Loss: 16708.5769
Particle 11/30 - Validation Loss: 15928.6277
Particle 12/30 - Validation Loss: 4323.6353
Particle 13/30 - Validation Loss: 56362.0871
Particle 14/30 - Validation Loss: 27413.4334
Particle 15/30 - Validation Loss: 16601.5356
Particle 16/30 - Validation Loss: 11469.4263
Particle 17/30 - Validation Loss: 12025.9076
Particle 18/30 - Validation Loss: 21931.5010
Particle 19/30 - Validation Loss: 49456.1472
Particle 20/30 - Validation Loss: 22855.5223
Particle 21/30 - Validation Loss: 8741.0535
Particle 22/30 - Validation Loss: 26198.2599
Particle 23/30 - Valida

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Particle 1/30 - Validation Loss: 6166.9851
Particle 2/30 - Validation Loss: 5093.4265
Particle 3/30 - Validation Loss: 95258.3358
Particle 4/30 - Validation Loss: 47508.8821
Particle 5/30 - Validation Loss: 12547.9746
Particle 6/30 - Validation Loss: 18249.0930
Particle 7/30 - Validation Loss: 12967.5839
Particle 8/30 - Validation Loss: 33693.6014
Particle 9/30 - Validation Loss: 9884.6707
Particle 10/30 - Validation Loss: 31707.8843
Particle 11/30 - Validation Loss: 13547.2056
Particle 12/30 - Validation Loss: 2186.3778
Particle 13/30 - Validation Loss: 31299.0306
Particle 14/30 - Validation Loss: 38630.1898
Particle 15/30 - Validation Loss: 7072.3846
Particle 16/30 - Validation Loss: 19009.5083
Particle 17/30 - Validation Loss: 14031.0186
Particle 18/30 - Validation Loss: 40857.8084
Particle 19/30 - Validation Loss: 37281.1267
Particle 20/30 - Validation Loss: 18459.5711
Particle 21/30 - Validation Loss: 12020.1748
Particle 22/30 - Validation Loss: 4876.6442
Particle 23/30 - Validati

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Particle 1/30 - Validation Loss: 11627.1495
Particle 2/30 - Validation Loss: 4320.9344
Particle 3/30 - Validation Loss: 61185.8088
Particle 4/30 - Validation Loss: 67314.3527
Particle 5/30 - Validation Loss: 13099.9228
Particle 6/30 - Validation Loss: 6642.4717
Particle 7/30 - Validation Loss: 11634.6453
Particle 8/30 - Validation Loss: 93495.3498
Particle 9/30 - Validation Loss: 10917.2530
Particle 10/30 - Validation Loss: 40480.4574
Particle 11/30 - Validation Loss: 8029.6583
Particle 12/30 - Validation Loss: 3747.0557
Particle 13/30 - Validation Loss: 28178.2702
Particle 14/30 - Validation Loss: 62152.8115
Particle 15/30 - Validation Loss: 14304.3803
Particle 16/30 - Validation Loss: 21639.9160
Particle 17/30 - Validation Loss: 15028.7987
Particle 18/30 - Validation Loss: 30671.8279
Particle 19/30 - Validation Loss: 21482.1785
Particle 20/30 - Validation Loss: 37850.0344
Particle 21/30 - Validation Loss: 5479.1081
Particle 22/30 - Validation Loss: 18776.1454
Particle 23/30 - Validat

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                       precision    recall  f1-score   support

             Backdoor       0.02      0.07      0.03      4805
            DDoS_HTTP       0.02      0.40      0.05      9709
            DDoS_ICMP       0.00      0.00      0.00     13588
             DDoS_TCP       0.00      0.00      0.00     10012
             DDoS_UDP       0.00      0.00      0.00     24313
       Fingerprinting       0.00      0.00      0.00       171
                 MITM       0.00      0.00      0.00        72
               Normal       0.64      0.37      0.47    272799
             Password       0.00      0.00      0.00      9987
        Port_Scanning       0.00      0.02      0.01      3995
           Ransomware       0.00      0.00      0.00      1938
        SQL_injection       0.00      0.00      0.00     10165
            Uploading       0.01      0.00      0.00      7362
Vulnerability_scanner       0.00      0.00      0.00     10005
                  XSS       0.00      0.00      0.00  

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                       precision    recall  f1-score   support

             Backdoor       0.02      0.07      0.03      4805
            DDoS_HTTP       0.02      0.40      0.05      9709
            DDoS_ICMP       0.00      0.00      0.00     13588
             DDoS_TCP       0.00      0.00      0.00     10012
             DDoS_UDP       0.00      0.00      0.00     24313
       Fingerprinting       0.00      0.00      0.00       171
                 MITM       0.00      0.00      0.00        72
               Normal       0.64      0.37      0.47    272799
             Password       0.00      0.00      0.00      9987
        Port_Scanning       0.00      0.02      0.01      3995
           Ransomware       0.00      0.00      0.00      1938
        SQL_injection       0.00      0.00      0.00     10165
            Uploading       0.01      0.00      0.00      7362
Vulnerability_scanner       0.00      0.00      0.00     10005
                  XSS       0.00      0.00      0.00  

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                       precision    recall  f1-score   support

             Backdoor       0.02      0.07      0.03      4805
            DDoS_HTTP       0.02      0.40      0.05      9709
            DDoS_ICMP       0.00      0.00      0.00     13588
             DDoS_TCP       0.00      0.00      0.00     10012
             DDoS_UDP       0.00      0.00      0.00     24313
       Fingerprinting       0.00      0.00      0.00       171
                 MITM       0.00      0.00      0.00        72
               Normal       0.64      0.37      0.47    272799
             Password       0.00      0.00      0.00      9987
        Port_Scanning       0.00      0.02      0.01      3995
           Ransomware       0.00      0.00      0.00      1938
        SQL_injection       0.00      0.00      0.00     10165
            Uploading       0.01      0.00      0.00      7362
Vulnerability_scanner       0.00      0.00      0.00     10005
                  XSS       0.00      0.00      0.00  

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                       precision    recall  f1-score   support

             Backdoor       0.02      0.07      0.03      4805
            DDoS_HTTP       0.02      0.40      0.05      9709
            DDoS_ICMP       0.00      0.00      0.00     13588
             DDoS_TCP       0.00      0.00      0.00     10012
             DDoS_UDP       0.00      0.00      0.00     24313
       Fingerprinting       0.00      0.00      0.00       171
                 MITM       0.00      0.00      0.00        72
               Normal       0.64      0.37      0.47    272799
             Password       0.00      0.00      0.00      9987
        Port_Scanning       0.00      0.02      0.01      3995
           Ransomware       0.00      0.00      0.00      1938
        SQL_injection       0.00      0.00      0.00     10165
            Uploading       0.01      0.00      0.00      7362
Vulnerability_scanner       0.00      0.00      0.00     10005
                  XSS       0.00      0.00      0.00  

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                       precision    recall  f1-score   support

             Backdoor       0.02      0.07      0.03      4805
            DDoS_HTTP       0.02      0.40      0.05      9709
            DDoS_ICMP       0.00      0.00      0.00     13588
             DDoS_TCP       0.00      0.00      0.00     10012
             DDoS_UDP       0.00      0.00      0.00     24313
       Fingerprinting       0.00      0.00      0.00       171
                 MITM       0.00      0.00      0.00        72
               Normal       0.64      0.37      0.47    272799
             Password       0.00      0.00      0.00      9987
        Port_Scanning       0.00      0.02      0.01      3995
           Ransomware       0.00      0.00      0.00      1938
        SQL_injection       0.00      0.00      0.00     10165
            Uploading       0.01      0.00      0.00      7362
Vulnerability_scanner       0.00      0.00      0.00     10005
                  XSS       0.00      0.00      0.00  

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                       precision    recall  f1-score   support

             Backdoor       0.02      0.07      0.03      4805
            DDoS_HTTP       0.02      0.40      0.05      9709
            DDoS_ICMP       0.00      0.00      0.00     13588
             DDoS_TCP       0.00      0.00      0.00     10012
             DDoS_UDP       0.00      0.00      0.00     24313
       Fingerprinting       0.00      0.00      0.00       171
                 MITM       0.00      0.00      0.00        72
               Normal       0.64      0.37      0.47    272799
             Password       0.00      0.00      0.00      9987
        Port_Scanning       0.00      0.02      0.01      3995
           Ransomware       0.00      0.00      0.00      1938
        SQL_injection       0.00      0.00      0.00     10165
            Uploading       0.01      0.00      0.00      7362
Vulnerability_scanner       0.00      0.00      0.00     10005
                  XSS       0.00      0.00      0.00  

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                       precision    recall  f1-score   support

             Backdoor       0.00      0.00      0.00      4805
            DDoS_HTTP       0.02      0.23      0.04      9709
            DDoS_ICMP       0.00      0.00      0.00     13588
             DDoS_TCP       0.00      0.00      0.00     10012
             DDoS_UDP       0.00      0.00      0.00     24313
       Fingerprinting       0.00      0.00      0.00       171
                 MITM       0.00      0.00      0.00        72
               Normal       0.70      0.58      0.64    272799
             Password       0.00      0.00      0.00      9987
        Port_Scanning       0.00      0.00      0.00      3995
           Ransomware       0.00      0.00      0.00      1938
        SQL_injection       0.00      0.00      0.00     10165
            Uploading       0.10      0.18      0.13      7362
Vulnerability_scanner       0.00      0.00      0.00     10005
                  XSS       0.00      0.00      0.00  

In [None]:
# -------------------- 6. Final Evaluation -------------------- #
model.eval()
test_preds, test_true = [], []
with torch.no_grad():
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        output = model(data)
        pred = output.argmax(dim=1)
        test_preds.extend(pred.cpu().numpy())
        test_true.extend(target.cpu().numpy())

test_acc = accuracy_score(test_true, test_preds)
print(f"\nFinal Test Accuracy: {test_acc:.4f}")
print("Test Classification Report:\n", classification_report(test_true, test_preds, target_names=le.classes_))

In [13]:
len(le.classes_)

15

In [12]:
torch.save(model.state_dict(), "tuesday_model.pth")

In [14]:
X.shape[2]

39

In [19]:
df.columns

Index(['arp.opcode', 'arp.hw.size', 'icmp.checksum', 'icmp.seq_le',
       'icmp.unused', 'http.content_length', 'http.response', 'http.tls_port',
       'tcp.ack', 'tcp.ack_raw', 'tcp.checksum', 'tcp.connection.fin',
       'tcp.connection.rst', 'tcp.connection.syn', 'tcp.connection.synack',
       'tcp.flags', 'tcp.flags.ack', 'tcp.len', 'tcp.seq', 'udp.stream',
       'udp.time_delta', 'dns.qry.name', 'dns.qry.qu', 'dns.qry.type',
       'dns.retransmission', 'dns.retransmit_request',
       'dns.retransmit_request_in', 'mqtt.conflag.cleansess', 'mqtt.conflags',
       'mqtt.hdrflags', 'mqtt.len', 'mqtt.msg_decoded_as', 'mqtt.msgtype',
       'mqtt.proto_len', 'mqtt.topic_len', 'mqtt.ver', 'mbtcp.len',
       'mbtcp.trans_id', 'mbtcp.unit_id'],
      dtype='object')