In [1]:
import torch
import numpy as np
import pandas as pd
import torch.nn.functional as F
import os

In [2]:
def load_csv_data(input_folder: str,
                  file_name: str):
    """
    Read a single CSV file from disk into a DataFrame.

    Args:
      input_folder: directory that contains the file.
      file_name: name of the CSV file inside ``input_folder``.

    Returns:
      pandas.DataFrame holding the parsed contents of that one file.
    """
    return pd.read_csv(os.path.join(input_folder, file_name))

def extract_features_labels(df: pd.DataFrame):
    """
    Separate a DataFrame into a feature matrix and a label vector.

    Every column except the last is treated as a feature; the last column
    is treated as the label.

    Returns:
      (X, y): ``X`` holds the values of all columns but the last,
      ``y`` holds the values of the last column.
    """
    feature_frame = df.iloc[:, :-1]
    label_column = df.iloc[:, -1]
    return feature_frame.values, label_column.values

class DataLoader(object):
    """
    Minimal mini-batch iterator over a dataset and its labels.

    ``data`` must expose ``.shape[0]`` (number of samples) and, like
    ``labels``, must support indexing with a list of integer positions
    (NumPy arrays and torch tensors both do).

    Args:
      data: sample container; the first axis indexes samples.
      labels: label container with one entry per sample.
      batch_size: number of samples per batch (>= 1). The final batch is
        smaller when the sample count is not a multiple of ``batch_size``.
      shuffle: when True, sample order is re-drawn with ``np.random.shuffle``
        every time iteration starts.

    Raises:
      ValueError: if ``batch_size`` is smaller than 1, or if ``labels`` and
        ``data`` do not contain the same number of entries.
    """

    def __init__(self,
                 data,
                 labels,
                 batch_size=1,
                 shuffle=True):
        # Fail fast: a batch size of 0 would otherwise only surface later as a
        # ZeroDivisionError in __len__ or a range() error in __iter__.
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        # A length mismatch would otherwise raise mid-iteration or silently
        # ignore trailing labels.
        if len(labels) != data.shape[0]:
            raise ValueError(
                f"data has {data.shape[0]} samples but labels has {len(labels)} entries"
            )
        self.data = data
        self.labels = labels
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __len__(self):
        """Return the number of batches per pass (the last one may be partial)."""
        n = self.data.shape[0]
        # Integer ceiling division.
        return int((n + self.batch_size - 1) // self.batch_size)

    def __iter__(self):
        """Yield ``(data_batch, labels_batch)`` pairs covering every sample once."""
        n = self.data.shape[0]
        idxlist = list(range(n))
        if self.shuffle:
            # Uses NumPy's global RNG, so np.random.seed(...) controls the order.
            np.random.shuffle(idxlist)

        for start_idx in range(0, n, self.batch_size):
            # Slicing clamps at the end of the list, so the last batch is
            # simply shorter when n is not a multiple of batch_size.
            batch_idx = idxlist[start_idx:start_idx + self.batch_size]
            yield self.data[batch_idx], self.labels[batch_idx]

In [3]:
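# Reference only (not executed): label encodings. Presumably these are the
# mappings used to build the label columns -- verify against the preprocessing code.
# label_L1_mapping = {"MQTT": 0, "Benign": 1}
# label_L2_mapping = {"MQTT-DDoS-Connect_Flood": 0, "MQTT-DDoS-Publish_Flood": 1,
#                     "MQTT-DoS-Connect_Flood": 2, "MQTT-DoS-Publish_Flood": 3,
#                     "MQTT-Malformed_Data": 4, "Benign": 5}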

In [4]:
# Load data
# NOTE: absolute, machine-specific path -- adjust before running elsewhere.
input_folder = '/home/zyang44/Github/baseline_cicIOT/P1_structurelevel/efficiency/input_files'
test_fname = 'logiKNet_test_3994.csv'

test_df = load_csv_data(input_folder, test_fname)

# Named constants: the literal 5 was previously used both as the benign label
# and as the per-class sample size, which made the two easy to confuse.
BENIGN_LABEL = 5         # label value treated as the benign class
SAMPLES_PER_CLASS = 5    # rows drawn from each non-benign class
SAMPLING_SEED = 42       # fixed seed so the stratified draw is reproducible
sampling_rng = np.random.default_rng(SAMPLING_SEED)

# The label is the last column; select it once and reuse it.
label_col = test_df.iloc[:, -1]

# Count how many rows in each class
class_counts = label_col.value_counts()
print("Class counts in the test set:")
print(class_counts)

# Collect every index of the benign class, then draw up to
# SAMPLES_PER_CLASS indices (without replacement) from each other class.
benign_indices = test_df[label_col == BENIGN_LABEL].index.tolist()
other_classes = label_col[label_col != BENIGN_LABEL].unique()
stratified_indices = []
for cls in other_classes:
    cls_indices = test_df[label_col == cls].index.tolist()
    if len(cls_indices) >= SAMPLES_PER_CLASS:
        stratified_indices.extend(
            sampling_rng.choice(cls_indices, SAMPLES_PER_CLASS, replace=False).tolist()
        )
    else:
        # Fewer rows than requested: keep all of them.
        stratified_indices.extend(cls_indices)

print(len(benign_indices), "benign indices")
print(len(stratified_indices), "stratified indices from other classes")


Class counts in the test set:
label_L2
5    703
0    684
1    673
3    650
4    645
2    639
Name: count, dtype: int64
703 benign indices
25 stratified indices from other classes


In [16]:
# The loader below reads the test set sequentially (shuffle=False), so the rows
# are shuffled once here; this spreads the benign class (label_L2 = 5) randomly
# through the sequence.
#
# NOTE: reset_index(drop=True) renumbers the rows, so index values collected
# before this cell (e.g. benign_indices) no longer point at the same rows.
# NOTE: test_df is overwritten, so re-running this cell shuffles the
# already-shuffled frame again.
test_df = test_df.sample(frac=1, random_state=42)
test_df = test_df.reset_index(drop=True)

# Split into a feature array and a label vector.
X_test, y_test = extract_features_labels(test_df)

# Wrap both as tensors; one sample per batch, served in the order fixed above.
test_loader = DataLoader(
    torch.tensor(X_test, dtype=torch.float32),
    torch.tensor(y_test, dtype=torch.long),
    batch_size=1,
    shuffle=False,
)

In [17]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: cpu


In [18]:
from kan import KAN
from utils import MLP

def load_model_state(infer_model, model_save_folder, model_name, map_location=None):
    """
    Load saved weights from disk into an existing model and switch it to eval mode.

    Args:
      infer_model: an already-constructed model; the loaded object is passed
        to its ``load_state_dict``.
      model_save_folder: directory that contains the checkpoint file.
      model_name: file name of the checkpoint inside ``model_save_folder``.
      map_location: where ``torch.load`` should place the loaded tensors.
        When None (the default), the notebook-level ``device`` is used, which
        matches the previous behaviour.

    Returns:
      The same ``infer_model`` object, with weights loaded and ``eval()`` applied.
    """
    if map_location is None:
        # Backward-compatible default: rely on the notebook-level `device`.
        map_location = device

    checkpoint_path = os.path.join(model_save_folder, model_name)
    state_dict = torch.load(
        checkpoint_path,
        map_location=map_location,
        weights_only=True     # only load tensor weights, no pickle objects
    )
    infer_model.load_state_dict(state_dict)
    infer_model.eval()
    return infer_model
# ---------------------------------------------------------------------------
# Load model weights for testing
# ---------------------------------------------------------------------------
# NOTE: absolute, machine-specific path -- adjust before running elsewhere.
model_state_folder = '/home/zyang44/Github/baseline_cicIOT/P1_structurelevel/efficiency/model_weights'

# Load all four models: each is constructed with an 18-10-6 layout and then
# has its weights restored from its own file in model_state_folder.
mlp_infer = load_model_state(
    MLP(layer_sizes=(18, 10, 6)).to(device),
    model_state_folder,
    'mlp.pt',
)

logicmlp_infer = load_model_state(
    MLP(layer_sizes=(18, 10, 6)).to(device),
    model_state_folder,
    'logic_mlp.pt',
)

logiKNet_infer = load_model_state(
    KAN(width=[18, 10, 6], grid=5, k=3, seed=42, device=device),
    model_state_folder,
    'logiKNet.pt',
)

hierarchical_logiKNet_infer = load_model_state(
    KAN(width=[18, 10, 6], grid=5, k=3, seed=42, device=device),
    model_state_folder,
    'hierarchical_logiKNet.pt',
)


checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0


In [19]:
model = logiKNet_infer  # Choose the model you want to evaluate

model.eval()
# Not filled in this cell; kept because it is a notebook-level name that other
# cells may rely on -- TODO confirm and remove if unused.
batch_times = []

# no_grad: inference only, so no autograd bookkeeping is needed.
with torch.no_grad():
    for data, labels in test_loader:
        # The loader's tensors are created on the CPU while the models are
        # placed on `device`; move the inputs so the forward pass also works
        # when `device` is CUDA (a no-op on CPU).
        logits = model(data.to(device))
        # Predicted class = index of the largest logit for each sample.
        preds = torch.argmax(logits, dim=1)
        print(f"Predictions: {preds.cpu().numpy()}, Labels: {labels.cpu().numpy()}")

  self.subnode_actscale.append(torch.std(x, dim=0).detach())
  input_range = torch.std(preacts, dim=0) + 0.1
  output_range_spline = torch.std(postacts_numerical, dim=0) # for training, only penalize the spline part
  output_range = torch.std(postacts, dim=0) # for visualization, include the contribution from both spline + symbolic


Predictions: [3], Labels: [3]
Predictions: [3], Labels: [3]
Predictions: [4], Labels: [4]
Predictions: [0], Labels: [2]
Predictions: [1], Labels: [1]
Predictions: [5], Labels: [5]
Predictions: [2], Labels: [2]
Predictions: [3], Labels: [1]
Predictions: [1], Labels: [1]
Predictions: [5], Labels: [5]
Predictions: [0], Labels: [0]
Predictions: [2], Labels: [0]
Predictions: [5], Labels: [5]
Predictions: [5], Labels: [5]
Predictions: [4], Labels: [4]
Predictions: [2], Labels: [0]
Predictions: [2], Labels: [2]
Predictions: [3], Labels: [3]
Predictions: [2], Labels: [2]
Predictions: [3], Labels: [1]
Predictions: [3], Labels: [3]
Predictions: [3], Labels: [3]
Predictions: [3], Labels: [3]
Predictions: [4], Labels: [4]
Predictions: [3], Labels: [3]
Predictions: [2], Labels: [2]
Predictions: [5], Labels: [5]
Predictions: [3], Labels: [3]
Predictions: [3], Labels: [3]
Predictions: [3], Labels: [3]
Predictions: [4], Labels: [4]
Predictions: [5], Labels: [4]
Predictions: [0], Labels: [0]
Prediction

In [12]:
import numpy as np
import time

def simulate_packet_arrivals_real_time(lam=100, interval=1, duration=10, rng=None):
    """
    Simulate packet arrivals in real time using a Poisson distribution.

    For each interval, one packet count is drawn from a Poisson distribution
    with mean ``lam * interval`` and printed; the function then sleeps for
    ``interval`` seconds before the next draw (no sleep after the last one).

    :param lam: Average arrival rate per second (packets per second).
    :param interval: Length of each time interval in seconds.
    :param duration: Number of intervals to simulate. This equals the
        simulated time in seconds only when ``interval`` is 1.
    :param rng: Optional random source exposing ``poisson(mean)`` (e.g. a
        ``numpy.random.Generator``). Pass a seeded generator for reproducible
        output; when None, NumPy's global ``np.random`` state is used.
    :return: None. The counts are only printed.
    """
    sampler = np.random if rng is None else rng
    for i in range(duration):
        packets = sampler.poisson(lam * interval)
        print(f"Interval {i+1}: {packets} packets")
        # Pace the output in real time, but do not wait after the final interval.
        if i < duration - 1:
            time.sleep(interval)

# Run the simulation: ten 1-second intervals at an average of 100 packets per
# second, printing one count per interval (nine sleeps, none after the last).
simulate_packet_arrivals_real_time(lam=100, interval=1, duration=10)


Interval 1: 86 packets
Interval 2: 95 packets
Interval 3: 103 packets
Interval 4: 101 packets
Interval 5: 94 packets
Interval 6: 105 packets
Interval 7: 90 packets
Interval 8: 93 packets
Interval 9: 104 packets
Interval 10: 98 packets
